您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

668 行
21 KiB

  1. import tempfile
  2. import typing
  3. from collections import namedtuple
  4. from collections.abc import Sequence
  5. from shlex import shlex
  6. from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
  7. from starlette.concurrency import run_in_threadpool
  8. from starlette.types import Scope
  9. Address = namedtuple("Address", ["host", "port"])
  10. class URL:
  11. def __init__(
  12. self, url: str = "", scope: Scope = None, **components: typing.Any
  13. ) -> None:
  14. if scope is not None:
  15. assert not url, 'Cannot set both "url" and "scope".'
  16. assert not components, 'Cannot set both "scope" and "**components".'
  17. scheme = scope.get("scheme", "http")
  18. server = scope.get("server", None)
  19. path = scope.get("root_path", "") + scope["path"]
  20. query_string = scope.get("query_string", b"")
  21. host_header = None
  22. for key, value in scope["headers"]:
  23. if key == b"host":
  24. host_header = value.decode("latin-1")
  25. break
  26. if host_header is not None:
  27. url = f"{scheme}://{host_header}{path}"
  28. elif server is None:
  29. url = path
  30. else:
  31. host, port = server
  32. default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
  33. if port == default_port:
  34. url = f"{scheme}://{host}{path}"
  35. else:
  36. url = f"{scheme}://{host}:{port}{path}"
  37. if query_string:
  38. url += "?" + query_string.decode()
  39. elif components:
  40. assert not url, 'Cannot set both "url" and "**components".'
  41. url = URL("").replace(**components).components.geturl()
  42. self._url = url
  43. @property
  44. def components(self) -> SplitResult:
  45. if not hasattr(self, "_components"):
  46. self._components = urlsplit(self._url)
  47. return self._components
  48. @property
  49. def scheme(self) -> str:
  50. return self.components.scheme
  51. @property
  52. def netloc(self) -> str:
  53. return self.components.netloc
  54. @property
  55. def path(self) -> str:
  56. return self.components.path
  57. @property
  58. def query(self) -> str:
  59. return self.components.query
  60. @property
  61. def fragment(self) -> str:
  62. return self.components.fragment
  63. @property
  64. def username(self) -> typing.Union[None, str]:
  65. return self.components.username
  66. @property
  67. def password(self) -> typing.Union[None, str]:
  68. return self.components.password
  69. @property
  70. def hostname(self) -> typing.Union[None, str]:
  71. return self.components.hostname
  72. @property
  73. def port(self) -> typing.Optional[int]:
  74. return self.components.port
  75. @property
  76. def is_secure(self) -> bool:
  77. return self.scheme in ("https", "wss")
  78. def replace(self, **kwargs: typing.Any) -> "URL":
  79. if (
  80. "username" in kwargs
  81. or "password" in kwargs
  82. or "hostname" in kwargs
  83. or "port" in kwargs
  84. ):
  85. hostname = kwargs.pop("hostname", self.hostname)
  86. port = kwargs.pop("port", self.port)
  87. username = kwargs.pop("username", self.username)
  88. password = kwargs.pop("password", self.password)
  89. netloc = hostname
  90. if port is not None:
  91. netloc += f":{port}"
  92. if username is not None:
  93. userpass = username
  94. if password is not None:
  95. userpass += f":{password}"
  96. netloc = f"{userpass}@{netloc}"
  97. kwargs["netloc"] = netloc
  98. components = self.components._replace(**kwargs)
  99. return self.__class__(components.geturl())
  100. def include_query_params(self, **kwargs: typing.Any) -> "URL":
  101. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  102. params.update({str(key): str(value) for key, value in kwargs.items()})
  103. query = urlencode(params.multi_items())
  104. return self.replace(query=query)
  105. def replace_query_params(self, **kwargs: typing.Any) -> "URL":
  106. query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
  107. return self.replace(query=query)
  108. def remove_query_params(
  109. self, keys: typing.Union[str, typing.Sequence[str]]
  110. ) -> "URL":
  111. if isinstance(keys, str):
  112. keys = [keys]
  113. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  114. for key in keys:
  115. params.pop(key, None)
  116. query = urlencode(params.multi_items())
  117. return self.replace(query=query)
  118. def __eq__(self, other: typing.Any) -> bool:
  119. return str(self) == str(other)
  120. def __str__(self) -> str:
  121. return self._url
  122. def __repr__(self) -> str:
  123. url = str(self)
  124. if self.password:
  125. url = str(self.replace(password="********"))
  126. return f"{self.__class__.__name__}({repr(url)})"
  127. class URLPath(str):
  128. """
  129. A URL path string that may also hold an associated protocol and/or host.
  130. Used by the routing to return `url_path_for` matches.
  131. """
  132. def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
  133. assert protocol in ("http", "websocket", "")
  134. return str.__new__(cls, path) # type: ignore
  135. def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
  136. self.protocol = protocol
  137. self.host = host
  138. def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
  139. if isinstance(base_url, str):
  140. base_url = URL(base_url)
  141. if self.protocol:
  142. scheme = {
  143. "http": {True: "https", False: "http"},
  144. "websocket": {True: "wss", False: "ws"},
  145. }[self.protocol][base_url.is_secure]
  146. else:
  147. scheme = base_url.scheme
  148. netloc = self.host or base_url.netloc
  149. path = base_url.path.rstrip("/") + str(self)
  150. return str(URL(scheme=scheme, netloc=netloc, path=path))
  151. class Secret:
  152. """
  153. Holds a string value that should not be revealed in tracebacks etc.
  154. You should cast the value to `str` at the point it is required.
  155. """
  156. def __init__(self, value: str):
  157. self._value = value
  158. def __repr__(self) -> str:
  159. class_name = self.__class__.__name__
  160. return f"{class_name}('**********')"
  161. def __str__(self) -> str:
  162. return self._value
  163. class CommaSeparatedStrings(Sequence):
  164. def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
  165. if isinstance(value, str):
  166. splitter = shlex(value, posix=True)
  167. splitter.whitespace = ","
  168. splitter.whitespace_split = True
  169. self._items = [item.strip() for item in splitter]
  170. else:
  171. self._items = list(value)
  172. def __len__(self) -> int:
  173. return len(self._items)
  174. def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
  175. return self._items[index]
  176. def __iter__(self) -> typing.Iterator[str]:
  177. return iter(self._items)
  178. def __repr__(self) -> str:
  179. class_name = self.__class__.__name__
  180. items = [item for item in self]
  181. return f"{class_name}({items!r})"
  182. def __str__(self) -> str:
  183. return ", ".join(repr(item) for item in self)
  184. class ImmutableMultiDict(typing.Mapping):
  185. def __init__(
  186. self,
  187. *args: typing.Union[
  188. "ImmutableMultiDict",
  189. typing.Mapping,
  190. typing.List[typing.Tuple[typing.Any, typing.Any]],
  191. ],
  192. **kwargs: typing.Any,
  193. ) -> None:
  194. assert len(args) < 2, "Too many arguments."
  195. value = args[0] if args else []
  196. if kwargs:
  197. value = (
  198. ImmutableMultiDict(value).multi_items()
  199. + ImmutableMultiDict(kwargs).multi_items()
  200. )
  201. if not value:
  202. _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
  203. elif hasattr(value, "multi_items"):
  204. value = typing.cast(ImmutableMultiDict, value)
  205. _items = list(value.multi_items())
  206. elif hasattr(value, "items"):
  207. value = typing.cast(typing.Mapping, value)
  208. _items = list(value.items())
  209. else:
  210. value = typing.cast(
  211. typing.List[typing.Tuple[typing.Any, typing.Any]], value
  212. )
  213. _items = list(value)
  214. self._dict = {k: v for k, v in _items}
  215. self._list = _items
  216. def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
  217. return [item_value for item_key, item_value in self._list if item_key == key]
  218. def keys(self) -> typing.KeysView:
  219. return self._dict.keys()
  220. def values(self) -> typing.ValuesView:
  221. return self._dict.values()
  222. def items(self) -> typing.ItemsView:
  223. return self._dict.items()
  224. def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
  225. return list(self._list)
  226. def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  227. if key in self._dict:
  228. return self._dict[key]
  229. return default
  230. def __getitem__(self, key: typing.Any) -> str:
  231. return self._dict[key]
  232. def __contains__(self, key: typing.Any) -> bool:
  233. return key in self._dict
  234. def __iter__(self) -> typing.Iterator[typing.Any]:
  235. return iter(self.keys())
  236. def __len__(self) -> int:
  237. return len(self._dict)
  238. def __eq__(self, other: typing.Any) -> bool:
  239. if not isinstance(other, self.__class__):
  240. return False
  241. return sorted(self._list) == sorted(other._list)
  242. def __repr__(self) -> str:
  243. class_name = self.__class__.__name__
  244. items = self.multi_items()
  245. return f"{class_name}({items!r})"
  246. class MultiDict(ImmutableMultiDict):
  247. def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
  248. self.setlist(key, [value])
  249. def __delitem__(self, key: typing.Any) -> None:
  250. self._list = [(k, v) for k, v in self._list if k != key]
  251. del self._dict[key]
  252. def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  253. self._list = [(k, v) for k, v in self._list if k != key]
  254. return self._dict.pop(key, default)
  255. def popitem(self) -> typing.Tuple:
  256. key, value = self._dict.popitem()
  257. self._list = [(k, v) for k, v in self._list if k != key]
  258. return key, value
  259. def poplist(self, key: typing.Any) -> typing.List:
  260. values = [v for k, v in self._list if k == key]
  261. self.pop(key)
  262. return values
  263. def clear(self) -> None:
  264. self._dict.clear()
  265. self._list.clear()
  266. def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  267. if key not in self:
  268. self._dict[key] = default
  269. self._list.append((key, default))
  270. return self[key]
  271. def setlist(self, key: typing.Any, values: typing.List) -> None:
  272. if not values:
  273. self.pop(key, None)
  274. else:
  275. existing_items = [(k, v) for (k, v) in self._list if k != key]
  276. self._list = existing_items + [(key, value) for value in values]
  277. self._dict[key] = values[-1]
  278. def append(self, key: typing.Any, value: typing.Any) -> None:
  279. self._list.append((key, value))
  280. self._dict[key] = value
  281. def update(
  282. self,
  283. *args: typing.Union[
  284. "MultiDict",
  285. typing.Mapping,
  286. typing.List[typing.Tuple[typing.Any, typing.Any]],
  287. ],
  288. **kwargs: typing.Any,
  289. ) -> None:
  290. value = MultiDict(*args, **kwargs)
  291. existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
  292. self._list = existing_items + value.multi_items()
  293. self._dict.update(value)
  294. class QueryParams(ImmutableMultiDict):
  295. """
  296. An immutable multidict.
  297. """
  298. def __init__(
  299. self,
  300. *args: typing.Union[
  301. "ImmutableMultiDict",
  302. typing.Mapping,
  303. typing.List[typing.Tuple[typing.Any, typing.Any]],
  304. str,
  305. bytes,
  306. ],
  307. **kwargs: typing.Any,
  308. ) -> None:
  309. assert len(args) < 2, "Too many arguments."
  310. value = args[0] if args else []
  311. if isinstance(value, str):
  312. super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
  313. elif isinstance(value, bytes):
  314. super().__init__(
  315. parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
  316. )
  317. else:
  318. super().__init__(*args, **kwargs) # type: ignore
  319. self._list = [(str(k), str(v)) for k, v in self._list]
  320. self._dict = {str(k): str(v) for k, v in self._dict.items()}
  321. def __str__(self) -> str:
  322. return urlencode(self._list)
  323. def __repr__(self) -> str:
  324. class_name = self.__class__.__name__
  325. query_string = str(self)
  326. return f"{class_name}({query_string!r})"
  327. class UploadFile:
  328. """
  329. An uploaded file included as part of the request data.
  330. """
  331. spool_max_size = 1024 * 1024
  332. def __init__(
  333. self, filename: str, file: typing.IO = None, content_type: str = ""
  334. ) -> None:
  335. self.filename = filename
  336. self.content_type = content_type
  337. if file is None:
  338. file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
  339. self.file = file
  340. @property
  341. def _in_memory(self) -> bool:
  342. rolled_to_disk = getattr(self.file, "_rolled", True)
  343. return not rolled_to_disk
  344. async def write(self, data: typing.Union[bytes, str]) -> None:
  345. if self._in_memory:
  346. self.file.write(data) # type: ignore
  347. else:
  348. await run_in_threadpool(self.file.write, data)
  349. async def read(self, size: int = -1) -> typing.Union[bytes, str]:
  350. if self._in_memory:
  351. return self.file.read(size)
  352. return await run_in_threadpool(self.file.read, size)
  353. async def seek(self, offset: int) -> None:
  354. if self._in_memory:
  355. self.file.seek(offset)
  356. else:
  357. await run_in_threadpool(self.file.seek, offset)
  358. async def close(self) -> None:
  359. if self._in_memory:
  360. self.file.close()
  361. else:
  362. await run_in_threadpool(self.file.close)
  363. class FormData(ImmutableMultiDict):
  364. """
  365. An immutable multidict, containing both file uploads and text input.
  366. """
  367. def __init__(
  368. self,
  369. *args: typing.Union[
  370. "FormData",
  371. typing.Mapping[str, typing.Union[str, UploadFile]],
  372. typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
  373. ],
  374. **kwargs: typing.Union[str, UploadFile],
  375. ) -> None:
  376. super().__init__(*args, **kwargs)
  377. async def close(self) -> None:
  378. for key, value in self.multi_items():
  379. if isinstance(value, UploadFile):
  380. await value.close()
  381. class Headers(typing.Mapping[str, str]):
  382. """
  383. An immutable, case-insensitive multidict.
  384. """
  385. def __init__(
  386. self,
  387. headers: typing.Mapping[str, str] = None,
  388. raw: typing.List[typing.Tuple[bytes, bytes]] = None,
  389. scope: Scope = None,
  390. ) -> None:
  391. self._list: typing.List[typing.Tuple[bytes, bytes]] = []
  392. if headers is not None:
  393. assert raw is None, 'Cannot set both "headers" and "raw".'
  394. assert scope is None, 'Cannot set both "headers" and "scope".'
  395. self._list = [
  396. (key.lower().encode("latin-1"), value.encode("latin-1"))
  397. for key, value in headers.items()
  398. ]
  399. elif raw is not None:
  400. assert scope is None, 'Cannot set both "raw" and "scope".'
  401. self._list = raw
  402. elif scope is not None:
  403. self._list = scope["headers"]
  404. @property
  405. def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
  406. return list(self._list)
  407. def keys(self) -> typing.List[str]: # type: ignore
  408. return [key.decode("latin-1") for key, value in self._list]
  409. def values(self) -> typing.List[str]: # type: ignore
  410. return [value.decode("latin-1") for key, value in self._list]
  411. def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
  412. return [
  413. (key.decode("latin-1"), value.decode("latin-1"))
  414. for key, value in self._list
  415. ]
  416. def get(self, key: str, default: typing.Any = None) -> typing.Any:
  417. try:
  418. return self[key]
  419. except KeyError:
  420. return default
  421. def getlist(self, key: str) -> typing.List[str]:
  422. get_header_key = key.lower().encode("latin-1")
  423. return [
  424. item_value.decode("latin-1")
  425. for item_key, item_value in self._list
  426. if item_key == get_header_key
  427. ]
  428. def mutablecopy(self) -> "MutableHeaders":
  429. return MutableHeaders(raw=self._list[:])
  430. def __getitem__(self, key: str) -> str:
  431. get_header_key = key.lower().encode("latin-1")
  432. for header_key, header_value in self._list:
  433. if header_key == get_header_key:
  434. return header_value.decode("latin-1")
  435. raise KeyError(key)
  436. def __contains__(self, key: typing.Any) -> bool:
  437. get_header_key = key.lower().encode("latin-1")
  438. for header_key, header_value in self._list:
  439. if header_key == get_header_key:
  440. return True
  441. return False
  442. def __iter__(self) -> typing.Iterator[typing.Any]:
  443. return iter(self.keys())
  444. def __len__(self) -> int:
  445. return len(self._list)
  446. def __eq__(self, other: typing.Any) -> bool:
  447. if not isinstance(other, Headers):
  448. return False
  449. return sorted(self._list) == sorted(other._list)
  450. def __repr__(self) -> str:
  451. class_name = self.__class__.__name__
  452. as_dict = dict(self.items())
  453. if len(as_dict) == len(self):
  454. return f"{class_name}({as_dict!r})"
  455. return f"{class_name}(raw={self.raw!r})"
  456. class MutableHeaders(Headers):
  457. def __setitem__(self, key: str, value: str) -> None:
  458. """
  459. Set the header `key` to `value`, removing any duplicate entries.
  460. Retains insertion order.
  461. """
  462. set_key = key.lower().encode("latin-1")
  463. set_value = value.encode("latin-1")
  464. found_indexes = []
  465. for idx, (item_key, item_value) in enumerate(self._list):
  466. if item_key == set_key:
  467. found_indexes.append(idx)
  468. for idx in reversed(found_indexes[1:]):
  469. del self._list[idx]
  470. if found_indexes:
  471. idx = found_indexes[0]
  472. self._list[idx] = (set_key, set_value)
  473. else:
  474. self._list.append((set_key, set_value))
  475. def __delitem__(self, key: str) -> None:
  476. """
  477. Remove the header `key`.
  478. """
  479. del_key = key.lower().encode("latin-1")
  480. pop_indexes = []
  481. for idx, (item_key, item_value) in enumerate(self._list):
  482. if item_key == del_key:
  483. pop_indexes.append(idx)
  484. for idx in reversed(pop_indexes):
  485. del self._list[idx]
  486. @property
  487. def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
  488. return self._list
  489. def setdefault(self, key: str, value: str) -> str:
  490. """
  491. If the header `key` does not exist, then set it to `value`.
  492. Returns the header value.
  493. """
  494. set_key = key.lower().encode("latin-1")
  495. set_value = value.encode("latin-1")
  496. for idx, (item_key, item_value) in enumerate(self._list):
  497. if item_key == set_key:
  498. return item_value.decode("latin-1")
  499. self._list.append((set_key, set_value))
  500. return value
  501. def update(self, other: dict) -> None:
  502. for key, val in other.items():
  503. self[key] = val
  504. def append(self, key: str, value: str) -> None:
  505. """
  506. Append a header, preserving any duplicate entries.
  507. """
  508. append_key = key.lower().encode("latin-1")
  509. append_value = value.encode("latin-1")
  510. self._list.append((append_key, append_value))
  511. def add_vary_header(self, vary: str) -> None:
  512. existing = self.get("vary")
  513. if existing is not None:
  514. vary = ", ".join([existing, vary])
  515. self["vary"] = vary
  516. class State:
  517. """
  518. An object that can be used to store arbitrary state.
  519. Used for `request.state` and `app.state`.
  520. """
  521. def __init__(self, state: typing.Dict = None):
  522. if state is None:
  523. state = {}
  524. super().__setattr__("_state", state)
  525. def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
  526. self._state[key] = value
  527. def __getattr__(self, key: typing.Any) -> typing.Any:
  528. try:
  529. return self._state[key]
  530. except KeyError:
  531. message = "'{}' object has no attribute '{}'"
  532. raise AttributeError(message.format(self.__class__.__name__, key))
  533. def __delattr__(self, key: typing.Any) -> None:
  534. del self._state[key]