You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

236 lines
9.4 KiB

  1. from __future__ import annotations
  2. from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Mapping, MutableMapping
  3. from typing import Any, Callable, TypeVar, Union
  4. import anyio
  5. from starlette._utils import collapse_excgroups
  6. from starlette.requests import ClientDisconnect, Request
  7. from starlette.responses import Response
  8. from starlette.types import ASGIApp, Message, Receive, Scope, Send
# Type of the ``call_next`` callable handed to a dispatch function: it takes the
# request and resolves to the downstream app's response.
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
# Type of a dispatch function: receives the request plus ``call_next`` and must
# return a response.
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
# Async generator that re-yields a captured response body: raw body bytes, or a
# whole ASGI message mapping for non-body messages such as
# ``http.response.pathsend``.
BodyStreamGenerator = AsyncGenerator[Union[bytes, MutableMapping[str, Any]], None]
# Type of the ``content`` stream accepted by ``_StreamingResponse``.
AsyncContentStream = AsyncIterable[Union[str, bytes, memoryview, MutableMapping[str, Any]]]
# Generic result type used by the ``wrap`` helper inside ``call_next``.
T = TypeVar("T")
  14. class _CachedRequest(Request):
  15. """
  16. If the user calls Request.body() from their dispatch function
  17. we cache the entire request body in memory and pass that to downstream middlewares,
  18. but if they call Request.stream() then all we do is send an
  19. empty body so that downstream things don't hang forever.
  20. """
  21. def __init__(self, scope: Scope, receive: Receive):
  22. super().__init__(scope, receive)
  23. self._wrapped_rcv_disconnected = False
  24. self._wrapped_rcv_consumed = False
  25. self._wrapped_rc_stream = self.stream()
  26. async def wrapped_receive(self) -> Message:
  27. # wrapped_rcv state 1: disconnected
  28. if self._wrapped_rcv_disconnected:
  29. # we've already sent a disconnect to the downstream app
  30. # we don't need to wait to get another one
  31. # (although most ASGI servers will just keep sending it)
  32. return {"type": "http.disconnect"}
  33. # wrapped_rcv state 1: consumed but not yet disconnected
  34. if self._wrapped_rcv_consumed:
  35. # since the downstream app has consumed us all that is left
  36. # is to send it a disconnect
  37. if self._is_disconnected:
  38. # the middleware has already seen the disconnect
  39. # since we know the client is disconnected no need to wait
  40. # for the message
  41. self._wrapped_rcv_disconnected = True
  42. return {"type": "http.disconnect"}
  43. # we don't know yet if the client is disconnected or not
  44. # so we'll wait until we get that message
  45. msg = await self.receive()
  46. if msg["type"] != "http.disconnect": # pragma: no cover
  47. # at this point a disconnect is all that we should be receiving
  48. # if we get something else, things went wrong somewhere
  49. raise RuntimeError(f"Unexpected message received: {msg['type']}")
  50. self._wrapped_rcv_disconnected = True
  51. return msg
  52. # wrapped_rcv state 3: not yet consumed
  53. if getattr(self, "_body", None) is not None:
  54. # body() was called, we return it even if the client disconnected
  55. self._wrapped_rcv_consumed = True
  56. return {
  57. "type": "http.request",
  58. "body": self._body,
  59. "more_body": False,
  60. }
  61. elif self._stream_consumed:
  62. # stream() was called to completion
  63. # return an empty body so that downstream apps don't hang
  64. # waiting for a disconnect
  65. self._wrapped_rcv_consumed = True
  66. return {
  67. "type": "http.request",
  68. "body": b"",
  69. "more_body": False,
  70. }
  71. else:
  72. # body() was never called and stream() wasn't consumed
  73. try:
  74. stream = self.stream()
  75. chunk = await stream.__anext__()
  76. self._wrapped_rcv_consumed = self._stream_consumed
  77. return {
  78. "type": "http.request",
  79. "body": chunk,
  80. "more_body": not self._stream_consumed,
  81. }
  82. except ClientDisconnect:
  83. self._wrapped_rcv_disconnected = True
  84. return {"type": "http.disconnect"}
class BaseHTTPMiddleware:
    """
    Base class for HTTP middleware written as a single ``dispatch`` coroutine.

    Either subclass and override ``dispatch``, or pass a ``dispatch`` function
    to the constructor. The dispatch function receives the request and a
    ``call_next`` callable that runs the wrapped app and returns its response.
    Non-HTTP scopes are forwarded to the wrapped app untouched.
    """

    def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
        # app: the downstream ASGI application being wrapped.
        # dispatch: optional replacement for the ``dispatch`` method.
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        ASGI entry point.

        Runs the dispatch function with a ``_CachedRequest`` and sends whatever
        response it returns. An exception raised by the downstream app is
        re-raised here unless ``call_next`` already raised it to the dispatch
        function.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = _CachedRequest(scope, receive)
        # Receive callable that replays any body the dispatch function read.
        wrapped_receive = request.wrapped_receive
        # Set once the outer response has been fully sent; from then on the
        # downstream app only receives ``http.disconnect``.
        response_sent = anyio.Event()
        # Exception raised by the downstream app, if any.
        app_exc: Exception | None = None
        # True when call_next already raised app_exc to the dispatch function,
        # so it must not be raised a second time at the end of __call__.
        exception_already_raised = False

        async def call_next(request: Request) -> Response:
            """
            Start the downstream app in the background and return its response.

            Messages sent by the app are funnelled through an in-memory stream;
            the returned response streams the remaining body messages lazily.

            Raises:
                Exception: the app's own exception, if the app failed before a
                    response message arrived.
                RuntimeError: if the app finished without sending a response.
            """

            async def receive_or_disconnect() -> Message:
                # Receive callable given to the downstream app: races the wrapped
                # receive against ``response_sent`` and reports a disconnect once
                # the outer response has been sent.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                async with anyio.create_task_group() as task_group:

                    async def wrap(func: Callable[[], Awaitable[T]]) -> T:
                        # Whichever racer finishes first cancels the other one.
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(wrapped_receive)

                # If the receive was cancelled because the response was sent,
                # ``message`` may never have been assigned; report a disconnect.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def send_no_error(message: Message) -> None:
                # Send callable given to the downstream app: forwards messages to
                # the in-memory stream and drops them once it has been closed.
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    return

            async def coro() -> None:
                # Runs the downstream app; leaving ``with send_stream`` closes the
                # send side, so the reader below sees end-of-stream.
                nonlocal app_exc

                with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        # Record the failure rather than failing the task group;
                        # it is re-raised by call_next or at the end of __call__.
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                # Skip a leading ``http.response.debug`` message, keeping its
                # info for the returned response.
                if message["type"] == "http.response.debug" and info is not None:
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app's send stream closed before a response message arrived.
                if app_exc is not None:
                    nonlocal exception_already_raised
                    exception_already_raised = True
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> BodyStreamGenerator:
                # Yields body bytes until ``more_body`` is false; a pathsend
                # message is yielded whole and ends the stream.
                async for message in recv_stream:
                    if message["type"] == "http.response.pathsend":
                        yield message
                        break
                    assert message["type"] == "http.response.body", f"Unexpected message: {message}"
                    body = message.get("body", b"")
                    if body:
                        yield body
                    if not message.get("more_body", False):
                        break

            response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
            # Reuse the app's headers verbatim instead of the ones init_headers built.
            response.raw_headers = message["headers"]
            return response

        streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
        send_stream, recv_stream = streams
        # collapse_excgroups presumably unwraps exception groups raised by the
        # task group (see starlette._utils) — confirm there.
        with recv_stream, send_stream, collapse_excgroups():
            # call_next starts the downstream app inside this task group.
            async with anyio.create_task_group() as task_group:
                response = await self.dispatch_func(request, call_next)
                await response(scope, wrapped_receive, send)
                response_sent.set()
                # After this, further sends from the app raise
                # BrokenResourceError, which send_no_error ignores.
                recv_stream.close()

        if app_exc is not None and not exception_already_raised:
            raise app_exc

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        # Override in a subclass, or pass ``dispatch=`` to __init__.
        raise NotImplementedError()  # pragma: no cover
  164. class _StreamingResponse(Response):
  165. def __init__(
  166. self,
  167. content: AsyncContentStream,
  168. status_code: int = 200,
  169. headers: Mapping[str, str] | None = None,
  170. media_type: str | None = None,
  171. info: Mapping[str, Any] | None = None,
  172. ) -> None:
  173. self.info = info
  174. self.body_iterator = content
  175. self.status_code = status_code
  176. self.media_type = media_type
  177. self.init_headers(headers)
  178. self.background = None
  179. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  180. if self.info is not None:
  181. await send({"type": "http.response.debug", "info": self.info})
  182. await send(
  183. {
  184. "type": "http.response.start",
  185. "status": self.status_code,
  186. "headers": self.raw_headers,
  187. }
  188. )
  189. should_close_body = True
  190. async for chunk in self.body_iterator:
  191. if isinstance(chunk, dict):
  192. # We got an ASGI message which is not response body (eg: pathsend)
  193. should_close_body = False
  194. await send(chunk)
  195. continue
  196. await send({"type": "http.response.body", "body": chunk, "more_body": True})
  197. if should_close_body:
  198. await send({"type": "http.response.body", "body": b"", "more_body": False})
  199. if self.background:
  200. await self.background()