Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.
 
 
 
 

693 Zeilen
22 KiB

  1. from __future__ import annotations
  2. from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
  3. from shlex import shlex
  4. from typing import (
  5. Any,
  6. BinaryIO,
  7. NamedTuple,
  8. TypeVar,
  9. Union,
  10. cast,
  11. )
  12. from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
  13. from starlette.concurrency import run_in_threadpool
  14. from starlette.types import Scope
  15. class Address(NamedTuple):
  16. host: str
  17. port: int
  18. _KeyType = TypeVar("_KeyType")
  19. # Mapping keys are invariant but their values are covariant since
  20. # you can only read them
  21. # that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
  22. _CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
  23. class URL:
  24. def __init__(
  25. self,
  26. url: str = "",
  27. scope: Scope | None = None,
  28. **components: Any,
  29. ) -> None:
  30. if scope is not None:
  31. assert not url, 'Cannot set both "url" and "scope".'
  32. assert not components, 'Cannot set both "scope" and "**components".'
  33. scheme = scope.get("scheme", "http")
  34. server = scope.get("server", None)
  35. path = scope["path"]
  36. query_string = scope.get("query_string", b"")
  37. host_header = None
  38. for key, value in scope["headers"]:
  39. if key == b"host":
  40. host_header = value.decode("latin-1")
  41. break
  42. if host_header is not None:
  43. url = f"{scheme}://{host_header}{path}"
  44. elif server is None:
  45. url = path
  46. else:
  47. host, port = server
  48. default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
  49. if port == default_port:
  50. url = f"{scheme}://{host}{path}"
  51. else:
  52. url = f"{scheme}://{host}:{port}{path}"
  53. if query_string:
  54. url += "?" + query_string.decode()
  55. elif components:
  56. assert not url, 'Cannot set both "url" and "**components".'
  57. url = URL("").replace(**components).components.geturl()
  58. self._url = url
  59. @property
  60. def components(self) -> SplitResult:
  61. if not hasattr(self, "_components"):
  62. self._components = urlsplit(self._url)
  63. return self._components
  64. @property
  65. def scheme(self) -> str:
  66. return self.components.scheme
  67. @property
  68. def netloc(self) -> str:
  69. return self.components.netloc
  70. @property
  71. def path(self) -> str:
  72. return self.components.path
  73. @property
  74. def query(self) -> str:
  75. return self.components.query
  76. @property
  77. def fragment(self) -> str:
  78. return self.components.fragment
  79. @property
  80. def username(self) -> None | str:
  81. return self.components.username
  82. @property
  83. def password(self) -> None | str:
  84. return self.components.password
  85. @property
  86. def hostname(self) -> None | str:
  87. return self.components.hostname
  88. @property
  89. def port(self) -> int | None:
  90. return self.components.port
  91. @property
  92. def is_secure(self) -> bool:
  93. return self.scheme in ("https", "wss")
  94. def replace(self, **kwargs: Any) -> URL:
  95. if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
  96. hostname = kwargs.pop("hostname", None)
  97. port = kwargs.pop("port", self.port)
  98. username = kwargs.pop("username", self.username)
  99. password = kwargs.pop("password", self.password)
  100. if hostname is None:
  101. netloc = self.netloc
  102. _, _, hostname = netloc.rpartition("@")
  103. if hostname[-1] != "]":
  104. hostname = hostname.rsplit(":", 1)[0]
  105. netloc = hostname
  106. if port is not None:
  107. netloc += f":{port}"
  108. if username is not None:
  109. userpass = username
  110. if password is not None:
  111. userpass += f":{password}"
  112. netloc = f"{userpass}@{netloc}"
  113. kwargs["netloc"] = netloc
  114. components = self.components._replace(**kwargs)
  115. return self.__class__(components.geturl())
  116. def include_query_params(self, **kwargs: Any) -> URL:
  117. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  118. params.update({str(key): str(value) for key, value in kwargs.items()})
  119. query = urlencode(params.multi_items())
  120. return self.replace(query=query)
  121. def replace_query_params(self, **kwargs: Any) -> URL:
  122. query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
  123. return self.replace(query=query)
  124. def remove_query_params(self, keys: str | Sequence[str]) -> URL:
  125. if isinstance(keys, str):
  126. keys = [keys]
  127. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  128. for key in keys:
  129. params.pop(key, None)
  130. query = urlencode(params.multi_items())
  131. return self.replace(query=query)
  132. def __eq__(self, other: Any) -> bool:
  133. return str(self) == str(other)
  134. def __str__(self) -> str:
  135. return self._url
  136. def __repr__(self) -> str:
  137. url = str(self)
  138. if self.password:
  139. url = str(self.replace(password="********"))
  140. return f"{self.__class__.__name__}({repr(url)})"
  141. class URLPath(str):
  142. """
  143. A URL path string that may also hold an associated protocol and/or host.
  144. Used by the routing to return `url_path_for` matches.
  145. """
  146. def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
  147. assert protocol in ("http", "websocket", "")
  148. return str.__new__(cls, path)
  149. def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
  150. self.protocol = protocol
  151. self.host = host
  152. def make_absolute_url(self, base_url: str | URL) -> URL:
  153. if isinstance(base_url, str):
  154. base_url = URL(base_url)
  155. if self.protocol:
  156. scheme = {
  157. "http": {True: "https", False: "http"},
  158. "websocket": {True: "wss", False: "ws"},
  159. }[self.protocol][base_url.is_secure]
  160. else:
  161. scheme = base_url.scheme
  162. netloc = self.host or base_url.netloc
  163. path = base_url.path.rstrip("/") + str(self)
  164. return URL(scheme=scheme, netloc=netloc, path=path)
  165. class Secret:
  166. """
  167. Holds a string value that should not be revealed in tracebacks etc.
  168. You should cast the value to `str` at the point it is required.
  169. """
  170. def __init__(self, value: str):
  171. self._value = value
  172. def __repr__(self) -> str:
  173. class_name = self.__class__.__name__
  174. return f"{class_name}('**********')"
  175. def __str__(self) -> str:
  176. return self._value
  177. def __bool__(self) -> bool:
  178. return bool(self._value)
  179. class CommaSeparatedStrings(Sequence[str]):
  180. def __init__(self, value: str | Sequence[str]):
  181. if isinstance(value, str):
  182. splitter = shlex(value, posix=True)
  183. splitter.whitespace = ","
  184. splitter.whitespace_split = True
  185. self._items = [item.strip() for item in splitter]
  186. else:
  187. self._items = list(value)
  188. def __len__(self) -> int:
  189. return len(self._items)
  190. def __getitem__(self, index: int | slice) -> Any:
  191. return self._items[index]
  192. def __iter__(self) -> Iterator[str]:
  193. return iter(self._items)
  194. def __repr__(self) -> str:
  195. class_name = self.__class__.__name__
  196. items = [item for item in self]
  197. return f"{class_name}({items!r})"
  198. def __str__(self) -> str:
  199. return ", ".join(repr(item) for item in self)
  200. class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
  201. _dict: dict[_KeyType, _CovariantValueType]
  202. def __init__(
  203. self,
  204. *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
  205. | Mapping[_KeyType, _CovariantValueType]
  206. | Iterable[tuple[_KeyType, _CovariantValueType]],
  207. **kwargs: Any,
  208. ) -> None:
  209. assert len(args) < 2, "Too many arguments."
  210. value: Any = args[0] if args else []
  211. if kwargs:
  212. value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
  213. if not value:
  214. _items: list[tuple[Any, Any]] = []
  215. elif hasattr(value, "multi_items"):
  216. value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
  217. _items = list(value.multi_items())
  218. elif hasattr(value, "items"):
  219. value = cast(Mapping[_KeyType, _CovariantValueType], value)
  220. _items = list(value.items())
  221. else:
  222. value = cast("list[tuple[Any, Any]]", value)
  223. _items = list(value)
  224. self._dict = {k: v for k, v in _items}
  225. self._list = _items
  226. def getlist(self, key: Any) -> list[_CovariantValueType]:
  227. return [item_value for item_key, item_value in self._list if item_key == key]
  228. def keys(self) -> KeysView[_KeyType]:
  229. return self._dict.keys()
  230. def values(self) -> ValuesView[_CovariantValueType]:
  231. return self._dict.values()
  232. def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
  233. return self._dict.items()
  234. def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
  235. return list(self._list)
  236. def __getitem__(self, key: _KeyType) -> _CovariantValueType:
  237. return self._dict[key]
  238. def __contains__(self, key: Any) -> bool:
  239. return key in self._dict
  240. def __iter__(self) -> Iterator[_KeyType]:
  241. return iter(self.keys())
  242. def __len__(self) -> int:
  243. return len(self._dict)
  244. def __eq__(self, other: Any) -> bool:
  245. if not isinstance(other, self.__class__):
  246. return False
  247. return sorted(self._list) == sorted(other._list)
  248. def __repr__(self) -> str:
  249. class_name = self.__class__.__name__
  250. items = self.multi_items()
  251. return f"{class_name}({items!r})"
  252. class MultiDict(ImmutableMultiDict[Any, Any]):
  253. def __setitem__(self, key: Any, value: Any) -> None:
  254. self.setlist(key, [value])
  255. def __delitem__(self, key: Any) -> None:
  256. self._list = [(k, v) for k, v in self._list if k != key]
  257. del self._dict[key]
  258. def pop(self, key: Any, default: Any = None) -> Any:
  259. self._list = [(k, v) for k, v in self._list if k != key]
  260. return self._dict.pop(key, default)
  261. def popitem(self) -> tuple[Any, Any]:
  262. key, value = self._dict.popitem()
  263. self._list = [(k, v) for k, v in self._list if k != key]
  264. return key, value
  265. def poplist(self, key: Any) -> list[Any]:
  266. values = [v for k, v in self._list if k == key]
  267. self.pop(key)
  268. return values
  269. def clear(self) -> None:
  270. self._dict.clear()
  271. self._list.clear()
  272. def setdefault(self, key: Any, default: Any = None) -> Any:
  273. if key not in self:
  274. self._dict[key] = default
  275. self._list.append((key, default))
  276. return self[key]
  277. def setlist(self, key: Any, values: list[Any]) -> None:
  278. if not values:
  279. self.pop(key, None)
  280. else:
  281. existing_items = [(k, v) for (k, v) in self._list if k != key]
  282. self._list = existing_items + [(key, value) for value in values]
  283. self._dict[key] = values[-1]
  284. def append(self, key: Any, value: Any) -> None:
  285. self._list.append((key, value))
  286. self._dict[key] = value
  287. def update(
  288. self,
  289. *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
  290. **kwargs: Any,
  291. ) -> None:
  292. value = MultiDict(*args, **kwargs)
  293. existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
  294. self._list = existing_items + value.multi_items()
  295. self._dict.update(value)
  296. class QueryParams(ImmutableMultiDict[str, str]):
  297. """
  298. An immutable multidict.
  299. """
  300. def __init__(
  301. self,
  302. *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
  303. **kwargs: Any,
  304. ) -> None:
  305. assert len(args) < 2, "Too many arguments."
  306. value = args[0] if args else []
  307. if isinstance(value, str):
  308. super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
  309. elif isinstance(value, bytes):
  310. super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
  311. else:
  312. super().__init__(*args, **kwargs) # type: ignore[arg-type]
  313. self._list = [(str(k), str(v)) for k, v in self._list]
  314. self._dict = {str(k): str(v) for k, v in self._dict.items()}
  315. def __str__(self) -> str:
  316. return urlencode(self._list)
  317. def __repr__(self) -> str:
  318. class_name = self.__class__.__name__
  319. query_string = str(self)
  320. return f"{class_name}({query_string!r})"
  321. class UploadFile:
  322. """
  323. An uploaded file included as part of the request data.
  324. """
  325. def __init__(
  326. self,
  327. file: BinaryIO,
  328. *,
  329. size: int | None = None,
  330. filename: str | None = None,
  331. headers: Headers | None = None,
  332. ) -> None:
  333. self.filename = filename
  334. self.file = file
  335. self.size = size
  336. self.headers = headers or Headers()
  337. # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
  338. # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
  339. self._max_mem_size = getattr(self.file, "_max_size", 0)
  340. @property
  341. def content_type(self) -> str | None:
  342. return self.headers.get("content-type", None)
  343. @property
  344. def _in_memory(self) -> bool:
  345. # check for SpooledTemporaryFile._rolled
  346. rolled_to_disk = getattr(self.file, "_rolled", True)
  347. return not rolled_to_disk
  348. def _will_roll(self, size_to_add: int) -> bool:
  349. # If we're not in_memory then we will always roll
  350. if not self._in_memory:
  351. return True
  352. # Check for SpooledTemporaryFile._max_size
  353. future_size = self.file.tell() + size_to_add
  354. return bool(future_size > self._max_mem_size) if self._max_mem_size else False
  355. async def write(self, data: bytes) -> None:
  356. new_data_len = len(data)
  357. if self.size is not None:
  358. self.size += new_data_len
  359. if self._will_roll(new_data_len):
  360. await run_in_threadpool(self.file.write, data)
  361. else:
  362. self.file.write(data)
  363. async def read(self, size: int = -1) -> bytes:
  364. if self._in_memory:
  365. return self.file.read(size)
  366. return await run_in_threadpool(self.file.read, size)
  367. async def seek(self, offset: int) -> None:
  368. if self._in_memory:
  369. self.file.seek(offset)
  370. else:
  371. await run_in_threadpool(self.file.seek, offset)
  372. async def close(self) -> None:
  373. if self._in_memory:
  374. self.file.close()
  375. else:
  376. await run_in_threadpool(self.file.close)
  377. def __repr__(self) -> str:
  378. return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
  379. class FormData(ImmutableMultiDict[str, Union[UploadFile, str]]):
  380. """
  381. An immutable multidict, containing both file uploads and text input.
  382. """
  383. def __init__(
  384. self,
  385. *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
  386. **kwargs: str | UploadFile,
  387. ) -> None:
  388. super().__init__(*args, **kwargs)
  389. async def close(self) -> None:
  390. for key, value in self.multi_items():
  391. if isinstance(value, UploadFile):
  392. await value.close()
  393. class Headers(Mapping[str, str]):
  394. """
  395. An immutable, case-insensitive multidict.
  396. """
  397. def __init__(
  398. self,
  399. headers: Mapping[str, str] | None = None,
  400. raw: list[tuple[bytes, bytes]] | None = None,
  401. scope: MutableMapping[str, Any] | None = None,
  402. ) -> None:
  403. self._list: list[tuple[bytes, bytes]] = []
  404. if headers is not None:
  405. assert raw is None, 'Cannot set both "headers" and "raw".'
  406. assert scope is None, 'Cannot set both "headers" and "scope".'
  407. self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
  408. elif raw is not None:
  409. assert scope is None, 'Cannot set both "raw" and "scope".'
  410. self._list = raw
  411. elif scope is not None:
  412. # scope["headers"] isn't necessarily a list
  413. # it might be a tuple or other iterable
  414. self._list = scope["headers"] = list(scope["headers"])
  415. @property
  416. def raw(self) -> list[tuple[bytes, bytes]]:
  417. return list(self._list)
  418. def keys(self) -> list[str]: # type: ignore[override]
  419. return [key.decode("latin-1") for key, value in self._list]
  420. def values(self) -> list[str]: # type: ignore[override]
  421. return [value.decode("latin-1") for key, value in self._list]
  422. def items(self) -> list[tuple[str, str]]: # type: ignore[override]
  423. return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
  424. def getlist(self, key: str) -> list[str]:
  425. get_header_key = key.lower().encode("latin-1")
  426. return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
  427. def mutablecopy(self) -> MutableHeaders:
  428. return MutableHeaders(raw=self._list[:])
  429. def __getitem__(self, key: str) -> str:
  430. get_header_key = key.lower().encode("latin-1")
  431. for header_key, header_value in self._list:
  432. if header_key == get_header_key:
  433. return header_value.decode("latin-1")
  434. raise KeyError(key)
  435. def __contains__(self, key: Any) -> bool:
  436. get_header_key = key.lower().encode("latin-1")
  437. for header_key, header_value in self._list:
  438. if header_key == get_header_key:
  439. return True
  440. return False
  441. def __iter__(self) -> Iterator[Any]:
  442. return iter(self.keys())
  443. def __len__(self) -> int:
  444. return len(self._list)
  445. def __eq__(self, other: Any) -> bool:
  446. if not isinstance(other, Headers):
  447. return False
  448. return sorted(self._list) == sorted(other._list)
  449. def __repr__(self) -> str:
  450. class_name = self.__class__.__name__
  451. as_dict = dict(self.items())
  452. if len(as_dict) == len(self):
  453. return f"{class_name}({as_dict!r})"
  454. return f"{class_name}(raw={self.raw!r})"
  455. class MutableHeaders(Headers):
  456. def __setitem__(self, key: str, value: str) -> None:
  457. """
  458. Set the header `key` to `value`, removing any duplicate entries.
  459. Retains insertion order.
  460. """
  461. set_key = key.lower().encode("latin-1")
  462. set_value = value.encode("latin-1")
  463. found_indexes: list[int] = []
  464. for idx, (item_key, item_value) in enumerate(self._list):
  465. if item_key == set_key:
  466. found_indexes.append(idx)
  467. for idx in reversed(found_indexes[1:]):
  468. del self._list[idx]
  469. if found_indexes:
  470. idx = found_indexes[0]
  471. self._list[idx] = (set_key, set_value)
  472. else:
  473. self._list.append((set_key, set_value))
  474. def __delitem__(self, key: str) -> None:
  475. """
  476. Remove the header `key`.
  477. """
  478. del_key = key.lower().encode("latin-1")
  479. pop_indexes: list[int] = []
  480. for idx, (item_key, item_value) in enumerate(self._list):
  481. if item_key == del_key:
  482. pop_indexes.append(idx)
  483. for idx in reversed(pop_indexes):
  484. del self._list[idx]
  485. def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
  486. if not isinstance(other, Mapping):
  487. raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
  488. self.update(other)
  489. return self
  490. def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
  491. if not isinstance(other, Mapping):
  492. raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
  493. new = self.mutablecopy()
  494. new.update(other)
  495. return new
  496. @property
  497. def raw(self) -> list[tuple[bytes, bytes]]:
  498. return self._list
  499. def setdefault(self, key: str, value: str) -> str:
  500. """
  501. If the header `key` does not exist, then set it to `value`.
  502. Returns the header value.
  503. """
  504. set_key = key.lower().encode("latin-1")
  505. set_value = value.encode("latin-1")
  506. for idx, (item_key, item_value) in enumerate(self._list):
  507. if item_key == set_key:
  508. return item_value.decode("latin-1")
  509. self._list.append((set_key, set_value))
  510. return value
  511. def update(self, other: Mapping[str, str]) -> None:
  512. for key, val in other.items():
  513. self[key] = val
  514. def append(self, key: str, value: str) -> None:
  515. """
  516. Append a header, preserving any duplicate entries.
  517. """
  518. append_key = key.lower().encode("latin-1")
  519. append_value = value.encode("latin-1")
  520. self._list.append((append_key, append_value))
  521. def add_vary_header(self, vary: str) -> None:
  522. existing = self.get("vary")
  523. if existing is not None:
  524. vary = ", ".join([existing, vary])
  525. self["vary"] = vary
  526. class State:
  527. """
  528. An object that can be used to store arbitrary state.
  529. Used for `request.state` and `app.state`.
  530. """
  531. _state: dict[str, Any]
  532. def __init__(self, state: dict[str, Any] | None = None):
  533. if state is None:
  534. state = {}
  535. super().__setattr__("_state", state)
  536. def __setattr__(self, key: Any, value: Any) -> None:
  537. self._state[key] = value
  538. def __getattr__(self, key: Any) -> Any:
  539. try:
  540. return self._state[key]
  541. except KeyError:
  542. message = "'{}' object has no attribute '{}'"
  543. raise AttributeError(message.format(self.__class__.__name__, key))
  544. def __delattr__(self, key: Any) -> None:
  545. del self._state[key]