You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

99 lines
3.5 KiB

  1. import asyncio
  2. import http
  3. import typing
  4. from starlette.concurrency import run_in_threadpool
  5. from starlette.requests import Request
  6. from starlette.responses import PlainTextResponse, Response
  7. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  8. class HTTPException(Exception):
  9. def __init__(self, status_code: int, detail: str = None) -> None:
  10. if detail is None:
  11. detail = http.HTTPStatus(status_code).phrase
  12. self.status_code = status_code
  13. self.detail = detail
  14. def __repr__(self) -> str:
  15. class_name = self.__class__.__name__
  16. return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
  17. class ExceptionMiddleware:
  18. def __init__(
  19. self, app: ASGIApp, handlers: dict = None, debug: bool = False
  20. ) -> None:
  21. self.app = app
  22. self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
  23. self._status_handlers: typing.Dict[int, typing.Callable] = {}
  24. self._exception_handlers: typing.Dict[
  25. typing.Type[Exception], typing.Callable
  26. ] = {HTTPException: self.http_exception}
  27. if handlers is not None:
  28. for key, value in handlers.items():
  29. self.add_exception_handler(key, value)
  30. def add_exception_handler(
  31. self,
  32. exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
  33. handler: typing.Callable,
  34. ) -> None:
  35. if isinstance(exc_class_or_status_code, int):
  36. self._status_handlers[exc_class_or_status_code] = handler
  37. else:
  38. assert issubclass(exc_class_or_status_code, Exception)
  39. self._exception_handlers[exc_class_or_status_code] = handler
  40. def _lookup_exception_handler(
  41. self, exc: Exception
  42. ) -> typing.Optional[typing.Callable]:
  43. for cls in type(exc).__mro__:
  44. if cls in self._exception_handlers:
  45. return self._exception_handlers[cls]
  46. return None
  47. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  48. if scope["type"] != "http":
  49. await self.app(scope, receive, send)
  50. return
  51. response_started = False
  52. async def sender(message: Message) -> None:
  53. nonlocal response_started
  54. if message["type"] == "http.response.start":
  55. response_started = True
  56. await send(message)
  57. try:
  58. await self.app(scope, receive, sender)
  59. except Exception as exc:
  60. handler = None
  61. if isinstance(exc, HTTPException):
  62. handler = self._status_handlers.get(exc.status_code)
  63. if handler is None:
  64. handler = self._lookup_exception_handler(exc)
  65. if handler is None:
  66. raise exc
  67. if response_started:
  68. msg = "Caught handled exception, but response already started."
  69. raise RuntimeError(msg) from exc
  70. request = Request(scope, receive=receive)
  71. if asyncio.iscoroutinefunction(handler):
  72. response = await handler(request, exc)
  73. else:
  74. response = await run_in_threadpool(handler, request, exc)
  75. await response(scope, receive, sender)
  76. def http_exception(self, request: Request, exc: HTTPException) -> Response:
  77. if exc.status_code in {204, 304}:
  78. return Response(b"", status_code=exc.status_code)
  79. return PlainTextResponse(exc.detail, status_code=exc.status_code)