您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

284 行
9.1 KiB

  1. import json
  2. import typing
  3. from collections.abc import Mapping
  4. from http import cookies as http_cookies
  5. import anyio
  6. from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
  7. from starlette.formparsers import FormParser, MultiPartParser
  8. from starlette.types import Message, Receive, Scope, Send
  9. try:
  10. from multipart.multipart import parse_options_header
  11. except ImportError: # pragma: nocover
  12. parse_options_header = None
  13. SERVER_PUSH_HEADERS_TO_COPY = {
  14. "accept",
  15. "accept-encoding",
  16. "accept-language",
  17. "cache-control",
  18. "user-agent",
  19. }
  20. def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
  21. """
  22. This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
  23. It attempts to mimic browser cookie parsing behavior: browsers and web servers
  24. frequently disregard the spec (RFC 6265) when setting and reading cookies,
  25. so we attempt to suit the common scenarios here.
  26. This function has been adapted from Django 3.1.0.
  27. Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
  28. on an outdated spec and will fail on lots of input we want to support
  29. """
  30. cookie_dict: typing.Dict[str, str] = {}
  31. for chunk in cookie_string.split(";"):
  32. if "=" in chunk:
  33. key, val = chunk.split("=", 1)
  34. else:
  35. # Assume an empty name per
  36. # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
  37. key, val = "", chunk
  38. key, val = key.strip(), val.strip()
  39. if key or val:
  40. # unquote using Python's algorithm.
  41. cookie_dict[key] = http_cookies._unquote(val) # type: ignore
  42. return cookie_dict
  43. class ClientDisconnect(Exception):
  44. pass
  45. class HTTPConnection(Mapping):
  46. """
  47. A base class for incoming HTTP connections, that is used to provide
  48. any functionality that is common to both `Request` and `WebSocket`.
  49. """
  50. def __init__(self, scope: Scope, receive: Receive = None) -> None:
  51. assert scope["type"] in ("http", "websocket")
  52. self.scope = scope
  53. def __getitem__(self, key: str) -> typing.Any:
  54. return self.scope[key]
  55. def __iter__(self) -> typing.Iterator[str]:
  56. return iter(self.scope)
  57. def __len__(self) -> int:
  58. return len(self.scope)
  59. # Don't use the `abc.Mapping.__eq__` implementation.
  60. # Connection instances should never be considered equal
  61. # unless `self is other`.
  62. __eq__ = object.__eq__
  63. __hash__ = object.__hash__
  64. @property
  65. def app(self) -> typing.Any:
  66. return self.scope["app"]
  67. @property
  68. def url(self) -> URL:
  69. if not hasattr(self, "_url"):
  70. self._url = URL(scope=self.scope)
  71. return self._url
  72. @property
  73. def base_url(self) -> URL:
  74. if not hasattr(self, "_base_url"):
  75. base_url_scope = dict(self.scope)
  76. base_url_scope["path"] = "/"
  77. base_url_scope["query_string"] = b""
  78. base_url_scope["root_path"] = base_url_scope.get(
  79. "app_root_path", base_url_scope.get("root_path", "")
  80. )
  81. self._base_url = URL(scope=base_url_scope)
  82. return self._base_url
  83. @property
  84. def headers(self) -> Headers:
  85. if not hasattr(self, "_headers"):
  86. self._headers = Headers(scope=self.scope)
  87. return self._headers
  88. @property
  89. def query_params(self) -> QueryParams:
  90. if not hasattr(self, "_query_params"):
  91. self._query_params = QueryParams(self.scope["query_string"])
  92. return self._query_params
  93. @property
  94. def path_params(self) -> dict:
  95. return self.scope.get("path_params", {})
  96. @property
  97. def cookies(self) -> typing.Dict[str, str]:
  98. if not hasattr(self, "_cookies"):
  99. cookies: typing.Dict[str, str] = {}
  100. cookie_header = self.headers.get("cookie")
  101. if cookie_header:
  102. cookies = cookie_parser(cookie_header)
  103. self._cookies = cookies
  104. return self._cookies
  105. @property
  106. def client(self) -> Address:
  107. host, port = self.scope.get("client") or (None, None)
  108. return Address(host=host, port=port)
  109. @property
  110. def session(self) -> dict:
  111. assert (
  112. "session" in self.scope
  113. ), "SessionMiddleware must be installed to access request.session"
  114. return self.scope["session"]
  115. @property
  116. def auth(self) -> typing.Any:
  117. assert (
  118. "auth" in self.scope
  119. ), "AuthenticationMiddleware must be installed to access request.auth"
  120. return self.scope["auth"]
  121. @property
  122. def user(self) -> typing.Any:
  123. assert (
  124. "user" in self.scope
  125. ), "AuthenticationMiddleware must be installed to access request.user"
  126. return self.scope["user"]
  127. @property
  128. def state(self) -> State:
  129. if not hasattr(self, "_state"):
  130. # Ensure 'state' has an empty dict if it's not already populated.
  131. self.scope.setdefault("state", {})
  132. # Create a state instance with a reference to the dict in which it should
  133. # store info
  134. self._state = State(self.scope["state"])
  135. return self._state
  136. def url_for(self, name: str, **path_params: typing.Any) -> str:
  137. router = self.scope["router"]
  138. url_path = router.url_path_for(name, **path_params)
  139. return url_path.make_absolute_url(base_url=self.base_url)
  140. async def empty_receive() -> Message:
  141. raise RuntimeError("Receive channel has not been made available")
  142. async def empty_send(message: Message) -> None:
  143. raise RuntimeError("Send channel has not been made available")
  144. class Request(HTTPConnection):
  145. def __init__(
  146. self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
  147. ):
  148. super().__init__(scope)
  149. assert scope["type"] == "http"
  150. self._receive = receive
  151. self._send = send
  152. self._stream_consumed = False
  153. self._is_disconnected = False
  154. @property
  155. def method(self) -> str:
  156. return self.scope["method"]
  157. @property
  158. def receive(self) -> Receive:
  159. return self._receive
  160. async def stream(self) -> typing.AsyncGenerator[bytes, None]:
  161. if hasattr(self, "_body"):
  162. yield self._body
  163. yield b""
  164. return
  165. if self._stream_consumed:
  166. raise RuntimeError("Stream consumed")
  167. self._stream_consumed = True
  168. while True:
  169. message = await self._receive()
  170. if message["type"] == "http.request":
  171. body = message.get("body", b"")
  172. if body:
  173. yield body
  174. if not message.get("more_body", False):
  175. break
  176. elif message["type"] == "http.disconnect":
  177. self._is_disconnected = True
  178. raise ClientDisconnect()
  179. yield b""
  180. async def body(self) -> bytes:
  181. if not hasattr(self, "_body"):
  182. chunks = []
  183. async for chunk in self.stream():
  184. chunks.append(chunk)
  185. self._body = b"".join(chunks)
  186. return self._body
  187. async def json(self) -> typing.Any:
  188. if not hasattr(self, "_json"):
  189. body = await self.body()
  190. self._json = json.loads(body)
  191. return self._json
  192. async def form(self) -> FormData:
  193. if not hasattr(self, "_form"):
  194. assert (
  195. parse_options_header is not None
  196. ), "The `python-multipart` library must be installed to use form parsing."
  197. content_type_header = self.headers.get("Content-Type")
  198. content_type, options = parse_options_header(content_type_header)
  199. if content_type == b"multipart/form-data":
  200. multipart_parser = MultiPartParser(self.headers, self.stream())
  201. self._form = await multipart_parser.parse()
  202. elif content_type == b"application/x-www-form-urlencoded":
  203. form_parser = FormParser(self.headers, self.stream())
  204. self._form = await form_parser.parse()
  205. else:
  206. self._form = FormData()
  207. return self._form
  208. async def close(self) -> None:
  209. if hasattr(self, "_form"):
  210. await self._form.close()
  211. async def is_disconnected(self) -> bool:
  212. if not self._is_disconnected:
  213. message: Message = {}
  214. # If message isn't immediately available, move on
  215. with anyio.CancelScope() as cs:
  216. cs.cancel()
  217. message = await self._receive()
  218. if message.get("type") == "http.disconnect":
  219. self._is_disconnected = True
  220. return self._is_disconnected
  221. async def send_push_promise(self, path: str) -> None:
  222. if "http.response.push" in self.scope.get("extensions", {}):
  223. raw_headers = []
  224. for name in SERVER_PUSH_HEADERS_TO_COPY:
  225. for value in self.headers.getlist(name):
  226. raw_headers.append(
  227. (name.encode("latin-1"), value.encode("latin-1"))
  228. )
  229. await self._send(
  230. {"type": "http.response.push", "path": path, "headers": raw_headers}
  231. )