You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

668 lines
21 KiB

  1. import tempfile
  2. import typing
  3. from collections import namedtuple
  4. from collections.abc import Sequence
  5. from shlex import shlex
  6. from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
  7. from starlette.concurrency import run_in_threadpool
  8. from starlette.types import Scope
  9. Address = namedtuple("Address", ["host", "port"])
  10. class URL:
  11. def __init__(
  12. self, url: str = "", scope: Scope = None, **components: typing.Any
  13. ) -> None:
  14. if scope is not None:
  15. assert not url, 'Cannot set both "url" and "scope".'
  16. assert not components, 'Cannot set both "scope" and "**components".'
  17. scheme = scope.get("scheme", "http")
  18. server = scope.get("server", None)
  19. path = scope.get("root_path", "") + scope["path"]
  20. query_string = scope.get("query_string", b"")
  21. host_header = None
  22. for key, value in scope["headers"]:
  23. if key == b"host":
  24. host_header = value.decode("latin-1")
  25. break
  26. if host_header is not None:
  27. url = f"{scheme}://{host_header}{path}"
  28. elif server is None:
  29. url = path
  30. else:
  31. host, port = server
  32. default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
  33. if port == default_port:
  34. url = f"{scheme}://{host}{path}"
  35. else:
  36. url = f"{scheme}://{host}:{port}{path}"
  37. if query_string:
  38. url += "?" + query_string.decode()
  39. elif components:
  40. assert not url, 'Cannot set both "url" and "**components".'
  41. url = URL("").replace(**components).components.geturl()
  42. self._url = url
  43. @property
  44. def components(self) -> SplitResult:
  45. if not hasattr(self, "_components"):
  46. self._components = urlsplit(self._url)
  47. return self._components
  48. @property
  49. def scheme(self) -> str:
  50. return self.components.scheme
  51. @property
  52. def netloc(self) -> str:
  53. return self.components.netloc
  54. @property
  55. def path(self) -> str:
  56. return self.components.path
  57. @property
  58. def query(self) -> str:
  59. return self.components.query
  60. @property
  61. def fragment(self) -> str:
  62. return self.components.fragment
  63. @property
  64. def username(self) -> typing.Union[None, str]:
  65. return self.components.username
  66. @property
  67. def password(self) -> typing.Union[None, str]:
  68. return self.components.password
  69. @property
  70. def hostname(self) -> typing.Union[None, str]:
  71. return self.components.hostname
  72. @property
  73. def port(self) -> typing.Optional[int]:
  74. return self.components.port
  75. @property
  76. def is_secure(self) -> bool:
  77. return self.scheme in ("https", "wss")
  78. def replace(self, **kwargs: typing.Any) -> "URL":
  79. if (
  80. "username" in kwargs
  81. or "password" in kwargs
  82. or "hostname" in kwargs
  83. or "port" in kwargs
  84. ):
  85. hostname = kwargs.pop("hostname", self.hostname)
  86. port = kwargs.pop("port", self.port)
  87. username = kwargs.pop("username", self.username)
  88. password = kwargs.pop("password", self.password)
  89. netloc = hostname
  90. if port is not None:
  91. netloc += f":{port}"
  92. if username is not None:
  93. userpass = username
  94. if password is not None:
  95. userpass += f":{password}"
  96. netloc = f"{userpass}@{netloc}"
  97. kwargs["netloc"] = netloc
  98. components = self.components._replace(**kwargs)
  99. return self.__class__(components.geturl())
  100. def include_query_params(self, **kwargs: typing.Any) -> "URL":
  101. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  102. params.update({str(key): str(value) for key, value in kwargs.items()})
  103. query = urlencode(params.multi_items())
  104. return self.replace(query=query)
  105. def replace_query_params(self, **kwargs: typing.Any) -> "URL":
  106. query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
  107. return self.replace(query=query)
  108. def remove_query_params(
  109. self, keys: typing.Union[str, typing.Sequence[str]]
  110. ) -> "URL":
  111. if isinstance(keys, str):
  112. keys = [keys]
  113. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  114. for key in keys:
  115. params.pop(key, None)
  116. query = urlencode(params.multi_items())
  117. return self.replace(query=query)
  118. def __eq__(self, other: typing.Any) -> bool:
  119. return str(self) == str(other)
  120. def __str__(self) -> str:
  121. return self._url
  122. def __repr__(self) -> str:
  123. url = str(self)
  124. if self.password:
  125. url = str(self.replace(password="********"))
  126. return f"{self.__class__.__name__}({repr(url)})"
  127. class URLPath(str):
  128. """
  129. A URL path string that may also hold an associated protocol and/or host.
  130. Used by the routing to return `url_path_for` matches.
  131. """
  132. def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
  133. assert protocol in ("http", "websocket", "")
  134. return str.__new__(cls, path) # type: ignore
  135. def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
  136. self.protocol = protocol
  137. self.host = host
  138. def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
  139. if isinstance(base_url, str):
  140. base_url = URL(base_url)
  141. if self.protocol:
  142. scheme = {
  143. "http": {True: "https", False: "http"},
  144. "websocket": {True: "wss", False: "ws"},
  145. }[self.protocol][base_url.is_secure]
  146. else:
  147. scheme = base_url.scheme
  148. netloc = self.host or base_url.netloc
  149. path = base_url.path.rstrip("/") + str(self)
  150. return str(URL(scheme=scheme, netloc=netloc, path=path))
  151. class Secret:
  152. """
  153. Holds a string value that should not be revealed in tracebacks etc.
  154. You should cast the value to `str` at the point it is required.
  155. """
  156. def __init__(self, value: str):
  157. self._value = value
  158. def __repr__(self) -> str:
  159. class_name = self.__class__.__name__
  160. return f"{class_name}('**********')"
  161. def __str__(self) -> str:
  162. return self._value
  163. class CommaSeparatedStrings(Sequence):
  164. def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
  165. if isinstance(value, str):
  166. splitter = shlex(value, posix=True)
  167. splitter.whitespace = ","
  168. splitter.whitespace_split = True
  169. self._items = [item.strip() for item in splitter]
  170. else:
  171. self._items = list(value)
  172. def __len__(self) -> int:
  173. return len(self._items)
  174. def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
  175. return self._items[index]
  176. def __iter__(self) -> typing.Iterator[str]:
  177. return iter(self._items)
  178. def __repr__(self) -> str:
  179. class_name = self.__class__.__name__
  180. items = [item for item in self]
  181. return f"{class_name}({items!r})"
  182. def __str__(self) -> str:
  183. return ", ".join(repr(item) for item in self)
  184. class ImmutableMultiDict(typing.Mapping):
  185. def __init__(
  186. self,
  187. *args: typing.Union[
  188. "ImmutableMultiDict",
  189. typing.Mapping,
  190. typing.List[typing.Tuple[typing.Any, typing.Any]],
  191. ],
  192. **kwargs: typing.Any,
  193. ) -> None:
  194. assert len(args) < 2, "Too many arguments."
  195. value = args[0] if args else []
  196. if kwargs:
  197. value = (
  198. ImmutableMultiDict(value).multi_items()
  199. + ImmutableMultiDict(kwargs).multi_items()
  200. )
  201. if not value:
  202. _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
  203. elif hasattr(value, "multi_items"):
  204. value = typing.cast(ImmutableMultiDict, value)
  205. _items = list(value.multi_items())
  206. elif hasattr(value, "items"):
  207. value = typing.cast(typing.Mapping, value)
  208. _items = list(value.items())
  209. else:
  210. value = typing.cast(
  211. typing.List[typing.Tuple[typing.Any, typing.Any]], value
  212. )
  213. _items = list(value)
  214. self._dict = {k: v for k, v in _items}
  215. self._list = _items
  216. def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
  217. return [item_value for item_key, item_value in self._list if item_key == key]
  218. def keys(self) -> typing.KeysView:
  219. return self._dict.keys()
  220. def values(self) -> typing.ValuesView:
  221. return self._dict.values()
  222. def items(self) -> typing.ItemsView:
  223. return self._dict.items()
  224. def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
  225. return list(self._list)
  226. def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  227. if key in self._dict:
  228. return self._dict[key]
  229. return default
  230. def __getitem__(self, key: typing.Any) -> str:
  231. return self._dict[key]
  232. def __contains__(self, key: typing.Any) -> bool:
  233. return key in self._dict
  234. def __iter__(self) -> typing.Iterator[typing.Any]:
  235. return iter(self.keys())
  236. def __len__(self) -> int:
  237. return len(self._dict)
  238. def __eq__(self, other: typing.Any) -> bool:
  239. if not isinstance(other, self.__class__):
  240. return False
  241. return sorted(self._list) == sorted(other._list)
  242. def __repr__(self) -> str:
  243. class_name = self.__class__.__name__
  244. items = self.multi_items()
  245. return f"{class_name}({items!r})"
  246. class MultiDict(ImmutableMultiDict):
  247. def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
  248. self.setlist(key, [value])
  249. def __delitem__(self, key: typing.Any) -> None:
  250. self._list = [(k, v) for k, v in self._list if k != key]
  251. del self._dict[key]
  252. def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  253. self._list = [(k, v) for k, v in self._list if k != key]
  254. return self._dict.pop(key, default)
  255. def popitem(self) -> typing.Tuple:
  256. key, value = self._dict.popitem()
  257. self._list = [(k, v) for k, v in self._list if k != key]
  258. return key, value
  259. def poplist(self, key: typing.Any) -> typing.List:
  260. values = [v for k, v in self._list if k == key]
  261. self.pop(key)
  262. return values
  263. def clear(self) -> None:
  264. self._dict.clear()
  265. self._list.clear()
  266. def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  267. if key not in self:
  268. self._dict[key] = default
  269. self._list.append((key, default))
  270. return self[key]
  271. def setlist(self, key: typing.Any, values: typing.List) -> None:
  272. if not values:
  273. self.pop(key, None)
  274. else:
  275. existing_items = [(k, v) for (k, v) in self._list if k != key]
  276. self._list = existing_items + [(key, value) for value in values]
  277. self._dict[key] = values[-1]
  278. def append(self, key: typing.Any, value: typing.Any) -> None:
  279. self._list.append((key, value))
  280. self._dict[key] = value
  281. def update(
  282. self,
  283. *args: typing.Union[
  284. "MultiDict",
  285. typing.Mapping,
  286. typing.List[typing.Tuple[typing.Any, typing.Any]],
  287. ],
  288. **kwargs: typing.Any,
  289. ) -> None:
  290. value = MultiDict(*args, **kwargs)
  291. existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
  292. self._list = existing_items + value.multi_items()
  293. self._dict.update(value)
  294. class QueryParams(ImmutableMultiDict):
  295. """
  296. An immutable multidict.
  297. """
  298. def __init__(
  299. self,
  300. *args: typing.Union[
  301. "ImmutableMultiDict",
  302. typing.Mapping,
  303. typing.List[typing.Tuple[typing.Any, typing.Any]],
  304. str,
  305. bytes,
  306. ],
  307. **kwargs: typing.Any,
  308. ) -> None:
  309. assert len(args) < 2, "Too many arguments."
  310. value = args[0] if args else []
  311. if isinstance(value, str):
  312. super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
  313. elif isinstance(value, bytes):
  314. super().__init__(
  315. parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
  316. )
  317. else:
  318. super().__init__(*args, **kwargs) # type: ignore
  319. self._list = [(str(k), str(v)) for k, v in self._list]
  320. self._dict = {str(k): str(v) for k, v in self._dict.items()}
  321. def __str__(self) -> str:
  322. return urlencode(self._list)
  323. def __repr__(self) -> str:
  324. class_name = self.__class__.__name__
  325. query_string = str(self)
  326. return f"{class_name}({query_string!r})"
  327. class UploadFile:
  328. """
  329. An uploaded file included as part of the request data.
  330. """
  331. spool_max_size = 1024 * 1024
  332. def __init__(
  333. self, filename: str, file: typing.IO = None, content_type: str = ""
  334. ) -> None:
  335. self.filename = filename
  336. self.content_type = content_type
  337. if file is None:
  338. file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
  339. self.file = file
  340. @property
  341. def _in_memory(self) -> bool:
  342. rolled_to_disk = getattr(self.file, "_rolled", True)
  343. return not rolled_to_disk
  344. async def write(self, data: typing.Union[bytes, str]) -> None:
  345. if self._in_memory:
  346. self.file.write(data) # type: ignore
  347. else:
  348. await run_in_threadpool(self.file.write, data)
  349. async def read(self, size: int = -1) -> typing.Union[bytes, str]:
  350. if self._in_memory:
  351. return self.file.read(size)
  352. return await run_in_threadpool(self.file.read, size)
  353. async def seek(self, offset: int) -> None:
  354. if self._in_memory:
  355. self.file.seek(offset)
  356. else:
  357. await run_in_threadpool(self.file.seek, offset)
  358. async def close(self) -> None:
  359. if self._in_memory:
  360. self.file.close()
  361. else:
  362. await run_in_threadpool(self.file.close)
  363. class FormData(ImmutableMultiDict):
  364. """
  365. An immutable multidict, containing both file uploads and text input.
  366. """
  367. def __init__(
  368. self,
  369. *args: typing.Union[
  370. "FormData",
  371. typing.Mapping[str, typing.Union[str, UploadFile]],
  372. typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
  373. ],
  374. **kwargs: typing.Union[str, UploadFile],
  375. ) -> None:
  376. super().__init__(*args, **kwargs)
  377. async def close(self) -> None:
  378. for key, value in self.multi_items():
  379. if isinstance(value, UploadFile):
  380. await value.close()
  381. class Headers(typing.Mapping[str, str]):
  382. """
  383. An immutable, case-insensitive multidict.
  384. """
  385. def __init__(
  386. self,
  387. headers: typing.Mapping[str, str] = None,
  388. raw: typing.List[typing.Tuple[bytes, bytes]] = None,
  389. scope: Scope = None,
  390. ) -> None:
  391. self._list: typing.List[typing.Tuple[bytes, bytes]] = []
  392. if headers is not None:
  393. assert raw is None, 'Cannot set both "headers" and "raw".'
  394. assert scope is None, 'Cannot set both "headers" and "scope".'
  395. self._list = [
  396. (key.lower().encode("latin-1"), value.encode("latin-1"))
  397. for key, value in headers.items()
  398. ]
  399. elif raw is not None:
  400. assert scope is None, 'Cannot set both "raw" and "scope".'
  401. self._list = raw
  402. elif scope is not None:
  403. self._list = scope["headers"]
  404. @property
  405. def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
  406. return list(self._list)
  407. def keys(self) -> typing.List[str]: # type: ignore
  408. return [key.decode("latin-1") for key, value in self._list]
  409. def values(self) -> typing.List[str]: # type: ignore
  410. return [value.decode("latin-1") for key, value in self._list]
  411. def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
  412. return [
  413. (key.decode("latin-1"), value.decode("latin-1"))
  414. for key, value in self._list
  415. ]
  416. def get(self, key: str, default: typing.Any = None) -> typing.Any:
  417. try:
  418. return self[key]
  419. except KeyError:
  420. return default
  421. def getlist(self, key: str) -> typing.List[str]:
  422. get_header_key = key.lower().encode("latin-1")
  423. return [
  424. item_value.decode("latin-1")
  425. for item_key, item_value in self._list
  426. if item_key == get_header_key
  427. ]
  428. def mutablecopy(self) -> "MutableHeaders":
  429. return MutableHeaders(raw=self._list[:])
  430. def __getitem__(self, key: str) -> str:
  431. get_header_key = key.lower().encode("latin-1")
  432. for header_key, header_value in self._list:
  433. if header_key == get_header_key:
  434. return header_value.decode("latin-1")
  435. raise KeyError(key)
  436. def __contains__(self, key: typing.Any) -> bool:
  437. get_header_key = key.lower().encode("latin-1")
  438. for header_key, header_value in self._list:
  439. if header_key == get_header_key:
  440. return True
  441. return False
  442. def __iter__(self) -> typing.Iterator[typing.Any]:
  443. return iter(self.keys())
  444. def __len__(self) -> int:
  445. return len(self._list)
  446. def __eq__(self, other: typing.Any) -> bool:
  447. if not isinstance(other, Headers):
  448. return False
  449. return sorted(self._list) == sorted(other._list)
  450. def __repr__(self) -> str:
  451. class_name = self.__class__.__name__
  452. as_dict = dict(self.items())
  453. if len(as_dict) == len(self):
  454. return f"{class_name}({as_dict!r})"
  455. return f"{class_name}(raw={self.raw!r})"
  456. class MutableHeaders(Headers):
  457. def __setitem__(self, key: str, value: str) -> None:
  458. """
  459. Set the header `key` to `value`, removing any duplicate entries.
  460. Retains insertion order.
  461. """
  462. set_key = key.lower().encode("latin-1")
  463. set_value = value.encode("latin-1")
  464. found_indexes = []
  465. for idx, (item_key, item_value) in enumerate(self._list):
  466. if item_key == set_key:
  467. found_indexes.append(idx)
  468. for idx in reversed(found_indexes[1:]):
  469. del self._list[idx]
  470. if found_indexes:
  471. idx = found_indexes[0]
  472. self._list[idx] = (set_key, set_value)
  473. else:
  474. self._list.append((set_key, set_value))
  475. def __delitem__(self, key: str) -> None:
  476. """
  477. Remove the header `key`.
  478. """
  479. del_key = key.lower().encode("latin-1")
  480. pop_indexes = []
  481. for idx, (item_key, item_value) in enumerate(self._list):
  482. if item_key == del_key:
  483. pop_indexes.append(idx)
  484. for idx in reversed(pop_indexes):
  485. del self._list[idx]
  486. @property
  487. def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
  488. return self._list
  489. def setdefault(self, key: str, value: str) -> str:
  490. """
  491. If the header `key` does not exist, then set it to `value`.
  492. Returns the header value.
  493. """
  494. set_key = key.lower().encode("latin-1")
  495. set_value = value.encode("latin-1")
  496. for idx, (item_key, item_value) in enumerate(self._list):
  497. if item_key == set_key:
  498. return item_value.decode("latin-1")
  499. self._list.append((set_key, set_value))
  500. return value
  501. def update(self, other: dict) -> None:
  502. for key, val in other.items():
  503. self[key] = val
  504. def append(self, key: str, value: str) -> None:
  505. """
  506. Append a header, preserving any duplicate entries.
  507. """
  508. append_key = key.lower().encode("latin-1")
  509. append_value = value.encode("latin-1")
  510. self._list.append((append_key, append_value))
  511. def add_vary_header(self, vary: str) -> None:
  512. existing = self.get("vary")
  513. if existing is not None:
  514. vary = ", ".join([existing, vary])
  515. self["vary"] = vary
  516. class State:
  517. """
  518. An object that can be used to store arbitrary state.
  519. Used for `request.state` and `app.state`.
  520. """
  521. def __init__(self, state: typing.Dict = None):
  522. if state is None:
  523. state = {}
  524. super().__setattr__("_state", state)
  525. def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
  526. self._state[key] = value
  527. def __getattr__(self, key: typing.Any) -> typing.Any:
  528. try:
  529. return self._state[key]
  530. except KeyError:
  531. message = "'{}' object has no attribute '{}'"
  532. raise AttributeError(message.format(self.__class__.__name__, key))
  533. def __delattr__(self, key: typing.Any) -> None:
  534. del self._state[key]