選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。
 
 
 
 

71 行
2.5 KiB

  1. import typing
  2. import anyio
  3. from starlette.requests import Request
  4. from starlette.responses import Response, StreamingResponse
  5. from starlette.types import ASGIApp, Receive, Scope, Send
  6. RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
  7. DispatchFunction = typing.Callable[
  8. [Request, RequestResponseEndpoint], typing.Awaitable[Response]
  9. ]
  10. class BaseHTTPMiddleware:
  11. def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None) -> None:
  12. self.app = app
  13. self.dispatch_func = self.dispatch if dispatch is None else dispatch
  14. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  15. if scope["type"] != "http":
  16. await self.app(scope, receive, send)
  17. return
  18. async def call_next(request: Request) -> Response:
  19. app_exc: typing.Optional[Exception] = None
  20. send_stream, recv_stream = anyio.create_memory_object_stream()
  21. async def coro() -> None:
  22. nonlocal app_exc
  23. async with send_stream:
  24. try:
  25. await self.app(scope, request.receive, send_stream.send)
  26. except Exception as exc:
  27. app_exc = exc
  28. task_group.start_soon(coro)
  29. try:
  30. message = await recv_stream.receive()
  31. except anyio.EndOfStream:
  32. if app_exc is not None:
  33. raise app_exc
  34. raise RuntimeError("No response returned.")
  35. assert message["type"] == "http.response.start"
  36. async def body_stream() -> typing.AsyncGenerator[bytes, None]:
  37. async with recv_stream:
  38. async for message in recv_stream:
  39. assert message["type"] == "http.response.body"
  40. yield message.get("body", b"")
  41. response = StreamingResponse(
  42. status_code=message["status"], content=body_stream()
  43. )
  44. response.raw_headers = message["headers"]
  45. return response
  46. async with anyio.create_task_group() as task_group:
  47. request = Request(scope, receive=receive)
  48. response = await self.dispatch_func(request, call_next)
  49. await response(scope, receive, send)
  50. task_group.cancel_scope.cancel()
  51. async def dispatch(
  52. self, request: Request, call_next: RequestResponseEndpoint
  53. ) -> Response:
  54. raise NotImplementedError() # pragma: no cover