You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

71 lines
2.5 KiB

  1. import typing
  2. import anyio
  3. from starlette.requests import Request
  4. from starlette.responses import Response, StreamingResponse
  5. from starlette.types import ASGIApp, Receive, Scope, Send
  6. RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
  7. DispatchFunction = typing.Callable[
  8. [Request, RequestResponseEndpoint], typing.Awaitable[Response]
  9. ]
  10. class BaseHTTPMiddleware:
  11. def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None) -> None:
  12. self.app = app
  13. self.dispatch_func = self.dispatch if dispatch is None else dispatch
  14. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  15. if scope["type"] != "http":
  16. await self.app(scope, receive, send)
  17. return
  18. async def call_next(request: Request) -> Response:
  19. app_exc: typing.Optional[Exception] = None
  20. send_stream, recv_stream = anyio.create_memory_object_stream()
  21. async def coro() -> None:
  22. nonlocal app_exc
  23. async with send_stream:
  24. try:
  25. await self.app(scope, request.receive, send_stream.send)
  26. except Exception as exc:
  27. app_exc = exc
  28. task_group.start_soon(coro)
  29. try:
  30. message = await recv_stream.receive()
  31. except anyio.EndOfStream:
  32. if app_exc is not None:
  33. raise app_exc
  34. raise RuntimeError("No response returned.")
  35. assert message["type"] == "http.response.start"
  36. async def body_stream() -> typing.AsyncGenerator[bytes, None]:
  37. async with recv_stream:
  38. async for message in recv_stream:
  39. assert message["type"] == "http.response.body"
  40. yield message.get("body", b"")
  41. response = StreamingResponse(
  42. status_code=message["status"], content=body_stream()
  43. )
  44. response.raw_headers = message["headers"]
  45. return response
  46. async with anyio.create_task_group() as task_group:
  47. request = Request(scope, receive=receive)
  48. response = await self.dispatch_func(request, call_next)
  49. await response(scope, receive, send)
  50. task_group.cancel_scope.cancel()
  51. async def dispatch(
  52. self, request: Request, call_next: RequestResponseEndpoint
  53. ) -> Response:
  54. raise NotImplementedError() # pragma: no cover