|
- from __future__ import annotations
-
- from typing import Callable
-
- from starlette.authentication import (
- AuthCredentials,
- AuthenticationBackend,
- AuthenticationError,
- UnauthenticatedUser,
- )
- from starlette.requests import HTTPConnection
- from starlette.responses import PlainTextResponse, Response
- from starlette.types import ASGIApp, Receive, Scope, Send
-
-
class AuthenticationMiddleware:
    """ASGI middleware that authenticates HTTP and WebSocket connections.

    For ``http`` and ``websocket`` scopes, the configured backend is asked to
    authenticate the connection. The resulting credentials and user are stored
    in ``scope["auth"]`` and ``scope["user"]`` before the wrapped app is called.
    All other scope types are passed straight through to the wrapped app.
    """

    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
    ) -> None:
        """Wrap ``app`` with authentication performed by ``backend``.

        Args:
            app: The downstream ASGI application.
            backend: Object whose ``authenticate(conn)`` coroutine returns a
                ``(credentials, user)`` pair, returns ``None``, or raises
                ``AuthenticationError``.
            on_error: Builds the response for a failed authentication. When
                omitted, ``default_on_error`` is used.
        """
        self.app = app
        self.backend = backend
        handler: Callable[[HTTPConnection, AuthenticationError], Response]
        if on_error is None:
            handler = self.default_on_error
        else:
            handler = on_error
        self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Authenticate the connection, then delegate to the wrapped app."""
        # Only HTTP and WebSocket connections are authenticated; anything else
        # (e.g. lifespan) is forwarded unchanged.
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        connection = HTTPConnection(scope)
        try:
            result = await self.backend.authenticate(connection)
        except AuthenticationError as error:
            # on_error runs for both protocols, so any side effects it has still
            # happen. Only HTTP connections actually send the built response.
            error_response = self.on_error(connection, error)
            if scope["type"] == "websocket":
                await send({"type": "websocket.close", "code": 1000})
                return
            await error_response(scope, receive, send)
            return

        # A backend returning None means "not authenticated, but not an error":
        # fall back to empty credentials and an anonymous user.
        if result is None:
            result = AuthCredentials(), UnauthenticatedUser()
        credentials, user = result
        scope["auth"] = credentials
        scope["user"] = user
        await self.app(scope, receive, send)

    @staticmethod
    def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
        """Return a plain-text 400 response whose body is ``str(exc)``."""
        message = str(exc)
        return PlainTextResponse(message, status_code=400)
|