You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

53 lines
1.7 KiB

  1. from __future__ import annotations
  2. from typing import Callable
  3. from starlette.authentication import (
  4. AuthCredentials,
  5. AuthenticationBackend,
  6. AuthenticationError,
  7. UnauthenticatedUser,
  8. )
  9. from starlette.requests import HTTPConnection
  10. from starlette.responses import PlainTextResponse, Response
  11. from starlette.types import ASGIApp, Receive, Scope, Send
  12. class AuthenticationMiddleware:
  13. def __init__(
  14. self,
  15. app: ASGIApp,
  16. backend: AuthenticationBackend,
  17. on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
  18. ) -> None:
  19. self.app = app
  20. self.backend = backend
  21. self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
  22. on_error if on_error is not None else self.default_on_error
  23. )
  24. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  25. if scope["type"] not in ["http", "websocket"]:
  26. await self.app(scope, receive, send)
  27. return
  28. conn = HTTPConnection(scope)
  29. try:
  30. auth_result = await self.backend.authenticate(conn)
  31. except AuthenticationError as exc:
  32. response = self.on_error(conn, exc)
  33. if scope["type"] == "websocket":
  34. await send({"type": "websocket.close", "code": 1000})
  35. else:
  36. await response(scope, receive, send)
  37. return
  38. if auth_result is None:
  39. auth_result = AuthCredentials(), UnauthenticatedUser()
  40. scope["auth"], scope["user"] = auth_result
  41. await self.app(scope, receive, send)
  42. @staticmethod
  43. def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
  44. return PlainTextResponse(str(exc), status_code=400)