You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

324 lines
11 KiB

  1. from __future__ import annotations
  2. import json
  3. from collections.abc import AsyncGenerator, Iterator, Mapping
  4. from http import cookies as http_cookies
  5. from typing import TYPE_CHECKING, Any, NoReturn, cast
  6. import anyio
  7. from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
  8. from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
  9. from starlette.exceptions import HTTPException
  10. from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
  11. from starlette.types import Message, Receive, Scope, Send
  12. if TYPE_CHECKING:
  13. from python_multipart.multipart import parse_options_header
  14. from starlette.applications import Starlette
  15. from starlette.routing import Router
  16. else:
  17. try:
  18. try:
  19. from python_multipart.multipart import parse_options_header
  20. except ModuleNotFoundError: # pragma: no cover
  21. from multipart.multipart import parse_options_header
  22. except ModuleNotFoundError: # pragma: no cover
  23. parse_options_header = None
  24. SERVER_PUSH_HEADERS_TO_COPY = {
  25. "accept",
  26. "accept-encoding",
  27. "accept-language",
  28. "cache-control",
  29. "user-agent",
  30. }
  31. def cookie_parser(cookie_string: str) -> dict[str, str]:
  32. """
  33. This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
  34. It attempts to mimic browser cookie parsing behavior: browsers and web servers
  35. frequently disregard the spec (RFC 6265) when setting and reading cookies,
  36. so we attempt to suit the common scenarios here.
  37. This function has been adapted from Django 3.1.0.
  38. Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
  39. on an outdated spec and will fail on lots of input we want to support
  40. """
  41. cookie_dict: dict[str, str] = {}
  42. for chunk in cookie_string.split(";"):
  43. if "=" in chunk:
  44. key, val = chunk.split("=", 1)
  45. else:
  46. # Assume an empty name per
  47. # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
  48. key, val = "", chunk
  49. key, val = key.strip(), val.strip()
  50. if key or val:
  51. # unquote using Python's algorithm.
  52. cookie_dict[key] = http_cookies._unquote(val)
  53. return cookie_dict
  54. class ClientDisconnect(Exception):
  55. pass
  56. class HTTPConnection(Mapping[str, Any]):
  57. """
  58. A base class for incoming HTTP connections, that is used to provide
  59. any functionality that is common to both `Request` and `WebSocket`.
  60. """
  61. def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
  62. assert scope["type"] in ("http", "websocket")
  63. self.scope = scope
  64. def __getitem__(self, key: str) -> Any:
  65. return self.scope[key]
  66. def __iter__(self) -> Iterator[str]:
  67. return iter(self.scope)
  68. def __len__(self) -> int:
  69. return len(self.scope)
  70. # Don't use the `abc.Mapping.__eq__` implementation.
  71. # Connection instances should never be considered equal
  72. # unless `self is other`.
  73. __eq__ = object.__eq__
  74. __hash__ = object.__hash__
  75. @property
  76. def app(self) -> Any:
  77. return self.scope["app"]
  78. @property
  79. def url(self) -> URL:
  80. if not hasattr(self, "_url"): # pragma: no branch
  81. self._url = URL(scope=self.scope)
  82. return self._url
  83. @property
  84. def base_url(self) -> URL:
  85. if not hasattr(self, "_base_url"):
  86. base_url_scope = dict(self.scope)
  87. # This is used by request.url_for, it might be used inside a Mount which
  88. # would have its own child scope with its own root_path, but the base URL
  89. # for url_for should still be the top level app root path.
  90. app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
  91. path = app_root_path
  92. if not path.endswith("/"):
  93. path += "/"
  94. base_url_scope["path"] = path
  95. base_url_scope["query_string"] = b""
  96. base_url_scope["root_path"] = app_root_path
  97. self._base_url = URL(scope=base_url_scope)
  98. return self._base_url
  99. @property
  100. def headers(self) -> Headers:
  101. if not hasattr(self, "_headers"):
  102. self._headers = Headers(scope=self.scope)
  103. return self._headers
  104. @property
  105. def query_params(self) -> QueryParams:
  106. if not hasattr(self, "_query_params"): # pragma: no branch
  107. self._query_params = QueryParams(self.scope["query_string"])
  108. return self._query_params
  109. @property
  110. def path_params(self) -> dict[str, Any]:
  111. return self.scope.get("path_params", {})
  112. @property
  113. def cookies(self) -> dict[str, str]:
  114. if not hasattr(self, "_cookies"):
  115. cookies: dict[str, str] = {}
  116. cookie_header = self.headers.get("cookie")
  117. if cookie_header:
  118. cookies = cookie_parser(cookie_header)
  119. self._cookies = cookies
  120. return self._cookies
  121. @property
  122. def client(self) -> Address | None:
  123. # client is a 2 item tuple of (host, port), None if missing
  124. host_port = self.scope.get("client")
  125. if host_port is not None:
  126. return Address(*host_port)
  127. return None
  128. @property
  129. def session(self) -> dict[str, Any]:
  130. assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
  131. return self.scope["session"] # type: ignore[no-any-return]
  132. @property
  133. def auth(self) -> Any:
  134. assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
  135. return self.scope["auth"]
  136. @property
  137. def user(self) -> Any:
  138. assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
  139. return self.scope["user"]
  140. @property
  141. def state(self) -> State:
  142. if not hasattr(self, "_state"):
  143. # Ensure 'state' has an empty dict if it's not already populated.
  144. self.scope.setdefault("state", {})
  145. # Create a state instance with a reference to the dict in which it should
  146. # store info
  147. self._state = State(self.scope["state"])
  148. return self._state
  149. def url_for(self, name: str, /, **path_params: Any) -> URL:
  150. url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
  151. if url_path_provider is None:
  152. raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
  153. url_path = url_path_provider.url_path_for(name, **path_params)
  154. return url_path.make_absolute_url(base_url=self.base_url)
  155. async def empty_receive() -> NoReturn:
  156. raise RuntimeError("Receive channel has not been made available")
  157. async def empty_send(message: Message) -> NoReturn:
  158. raise RuntimeError("Send channel has not been made available")
  159. class Request(HTTPConnection):
  160. _form: FormData | None
  161. def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
  162. super().__init__(scope)
  163. assert scope["type"] == "http"
  164. self._receive = receive
  165. self._send = send
  166. self._stream_consumed = False
  167. self._is_disconnected = False
  168. self._form = None
  169. @property
  170. def method(self) -> str:
  171. return cast(str, self.scope["method"])
  172. @property
  173. def receive(self) -> Receive:
  174. return self._receive
  175. async def stream(self) -> AsyncGenerator[bytes, None]:
  176. if hasattr(self, "_body"):
  177. yield self._body
  178. yield b""
  179. return
  180. if self._stream_consumed:
  181. raise RuntimeError("Stream consumed")
  182. while not self._stream_consumed:
  183. message = await self._receive()
  184. if message["type"] == "http.request":
  185. body = message.get("body", b"")
  186. if not message.get("more_body", False):
  187. self._stream_consumed = True
  188. if body:
  189. yield body
  190. elif message["type"] == "http.disconnect": # pragma: no branch
  191. self._is_disconnected = True
  192. raise ClientDisconnect()
  193. yield b""
  194. async def body(self) -> bytes:
  195. if not hasattr(self, "_body"):
  196. chunks: list[bytes] = []
  197. async for chunk in self.stream():
  198. chunks.append(chunk)
  199. self._body = b"".join(chunks)
  200. return self._body
  201. async def json(self) -> Any:
  202. if not hasattr(self, "_json"): # pragma: no branch
  203. body = await self.body()
  204. self._json = json.loads(body)
  205. return self._json
  206. async def _get_form(
  207. self,
  208. *,
  209. max_files: int | float = 1000,
  210. max_fields: int | float = 1000,
  211. max_part_size: int = 1024 * 1024,
  212. ) -> FormData:
  213. if self._form is None: # pragma: no branch
  214. assert parse_options_header is not None, (
  215. "The `python-multipart` library must be installed to use form parsing."
  216. )
  217. content_type_header = self.headers.get("Content-Type")
  218. content_type: bytes
  219. content_type, _ = parse_options_header(content_type_header)
  220. if content_type == b"multipart/form-data":
  221. try:
  222. multipart_parser = MultiPartParser(
  223. self.headers,
  224. self.stream(),
  225. max_files=max_files,
  226. max_fields=max_fields,
  227. max_part_size=max_part_size,
  228. )
  229. self._form = await multipart_parser.parse()
  230. except MultiPartException as exc:
  231. if "app" in self.scope:
  232. raise HTTPException(status_code=400, detail=exc.message)
  233. raise exc
  234. elif content_type == b"application/x-www-form-urlencoded":
  235. form_parser = FormParser(self.headers, self.stream())
  236. self._form = await form_parser.parse()
  237. else:
  238. self._form = FormData()
  239. return self._form
  240. def form(
  241. self,
  242. *,
  243. max_files: int | float = 1000,
  244. max_fields: int | float = 1000,
  245. max_part_size: int = 1024 * 1024,
  246. ) -> AwaitableOrContextManager[FormData]:
  247. return AwaitableOrContextManagerWrapper(
  248. self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
  249. )
  250. async def close(self) -> None:
  251. if self._form is not None: # pragma: no branch
  252. await self._form.close()
  253. async def is_disconnected(self) -> bool:
  254. if not self._is_disconnected:
  255. message: Message = {}
  256. # If message isn't immediately available, move on
  257. with anyio.CancelScope() as cs:
  258. cs.cancel()
  259. message = await self._receive()
  260. if message.get("type") == "http.disconnect":
  261. self._is_disconnected = True
  262. return self._is_disconnected
  263. async def send_push_promise(self, path: str) -> None:
  264. if "http.response.push" in self.scope.get("extensions", {}):
  265. raw_headers: list[tuple[bytes, bytes]] = []
  266. for name in SERVER_PUSH_HEADERS_TO_COPY:
  267. for value in self.headers.getlist(name):
  268. raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
  269. await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})