您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

53 行
1.7 KiB

  1. import typing
  2. from starlette.authentication import (
  3. AuthCredentials,
  4. AuthenticationBackend,
  5. AuthenticationError,
  6. UnauthenticatedUser,
  7. )
  8. from starlette.requests import HTTPConnection
  9. from starlette.responses import PlainTextResponse, Response
  10. from starlette.types import ASGIApp, Receive, Scope, Send
  11. class AuthenticationMiddleware:
  12. def __init__(
  13. self,
  14. app: ASGIApp,
  15. backend: AuthenticationBackend,
  16. on_error: typing.Callable[
  17. [HTTPConnection, AuthenticationError], Response
  18. ] = None,
  19. ) -> None:
  20. self.app = app
  21. self.backend = backend
  22. self.on_error: typing.Callable[
  23. [HTTPConnection, AuthenticationError], Response
  24. ] = (on_error if on_error is not None else self.default_on_error)
  25. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  26. if scope["type"] not in ["http", "websocket"]:
  27. await self.app(scope, receive, send)
  28. return
  29. conn = HTTPConnection(scope)
  30. try:
  31. auth_result = await self.backend.authenticate(conn)
  32. except AuthenticationError as exc:
  33. response = self.on_error(conn, exc)
  34. if scope["type"] == "websocket":
  35. await send({"type": "websocket.close", "code": 1000})
  36. else:
  37. await response(scope, receive, send)
  38. return
  39. if auth_result is None:
  40. auth_result = AuthCredentials(), UnauthenticatedUser()
  41. scope["auth"], scope["user"] = auth_result
  42. await self.app(scope, receive, send)
  43. @staticmethod
  44. def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
  45. return PlainTextResponse(str(exc), status_code=400)