|
- from __future__ import annotations
-
- import functools
- import inspect
- import sys
- from collections.abc import Sequence
- from typing import Any, Callable
- from urllib.parse import urlencode
-
- if sys.version_info >= (3, 10): # pragma: no cover
- from typing import ParamSpec
- else: # pragma: no cover
- from typing_extensions import ParamSpec
-
- from starlette._utils import is_async_callable
- from starlette.exceptions import HTTPException
- from starlette.requests import HTTPConnection, Request
- from starlette.responses import RedirectResponse
- from starlette.websockets import WebSocket
-
# Parameter specification shared by ``requires``: its decorator accepts a
# ``Callable[_P, Any]`` and returns a ``Callable[_P, Any]``, so the wrapper is
# typed with the same parameter list as the endpoint it wraps.
_P = ParamSpec("_P")
-
-
def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
    """Report whether *conn* has been granted every scope in *scopes*.

    The check stops at the first scope missing from ``conn.auth.scopes``.
    An empty *scopes* is trivially satisfied and never touches ``conn.auth``.
    """
    return all(scope in conn.auth.scopes for scope in scopes)
-
-
def requires(
    scopes: str | Sequence[str],
    status_code: int = 403,
    redirect: str | None = None,
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
    """Build a decorator that guards an endpoint behind authentication scopes.

    The decorated callable must declare a parameter named ``request`` or
    ``websocket``. That parameter's position is found once, at decoration
    time. Each call then reads the connection from the keyword arguments, or
    else from that positional slot.

    Args:
        scopes: A single scope name or a sequence of scope names. Every one
            must be present (see ``has_required_scope``) for the endpoint to run.
        status_code: Status code of the ``HTTPException`` raised when an HTTP
            request is missing a scope and no ``redirect`` is configured.
        redirect: Optional route name. When given, an HTTP request missing a
            scope receives a 303 ``RedirectResponse`` to
            ``request.url_for(redirect)``. The original URL is passed in a
            ``next`` query parameter.

    Returns:
        A decorator whose result keeps the wrapped callable's signature.

    Raises:
        Exception: At decoration time, when the callable has no ``request``
            or ``websocket`` parameter.
    """
    required_scopes = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(
        func: Callable[_P, Any],
    ) -> Callable[_P, Any]:
        # First parameter whose name identifies the connection, with its index.
        located = next(
            (
                (position, param.name)
                for position, param in enumerate(inspect.signature(func).parameters.values())
                if param.name in ("request", "websocket")
            ),
            None,
        )
        if located is None:
            raise Exception(f'No "request" or "websocket" argument on function "{func}"')
        slot, conn_kind = located

        def pick(name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            # A keyword argument wins; otherwise use the positional slot if it was passed.
            positional = args[slot] if slot < len(args) else None
            return kwargs.get(name, positional)

        def deny(request: Request) -> RedirectResponse:
            # Without a redirect target the request fails with HTTPException;
            # with one, send a 303 carrying the original URL as ``next``.
            if redirect is None:
                raise HTTPException(status_code=status_code)
            next_query = urlencode({"next": str(request.url)})
            target = f"{request.url_for(redirect)}?{next_query}"
            return RedirectResponse(url=target, status_code=303)

        if conn_kind == "websocket":
            # WebSocket handlers are always coroutines. On denial the socket is
            # closed and the handler is never invoked; the guard returns None.
            @functools.wraps(func)
            async def websocket_guard(*args: _P.args, **kwargs: _P.kwargs) -> None:
                websocket = pick("websocket", args, kwargs)
                assert isinstance(websocket, WebSocket)
                if has_required_scope(websocket, required_scopes):
                    await func(*args, **kwargs)
                else:
                    await websocket.close()

            return websocket_guard

        if is_async_callable(func):
            # Coroutine HTTP endpoint: the guard must itself be awaitable.
            @functools.wraps(func)
            async def async_guard(*args: _P.args, **kwargs: _P.kwargs) -> Any:
                request = pick("request", args, kwargs)
                assert isinstance(request, Request)
                if has_required_scope(request, required_scopes):
                    return await func(*args, **kwargs)
                return deny(request)

            return async_guard

        # Plain synchronous HTTP endpoint.
        @functools.wraps(func)
        def sync_guard(*args: _P.args, **kwargs: _P.kwargs) -> Any:
            request = pick("request", args, kwargs)
            assert isinstance(request, Request)
            if has_required_scope(request, required_scopes):
                return func(*args, **kwargs)
            return deny(request)

        return sync_guard

    return decorator
-
-
class AuthenticationError(Exception):
    """Error type for authentication failures; adds nothing beyond ``Exception``."""
-
-
class AuthenticationBackend:
    """Base class for pluggable authentication strategies.

    Subclasses override ``authenticate``; this base implementation only
    raises ``NotImplementedError``.
    """

    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
        """Resolve *conn* to a ``(credentials, user)`` pair, or ``None``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()  # pragma: no cover
-
-
class AuthCredentials:
    """Holds the list of scope names granted to a connection.

    Args:
        scopes: Scope names to grant. ``None`` (the default) grants none.
            A single ``str`` is treated as one scope name, matching how
            ``requires`` normalises its ``scopes`` argument. It is not split
            into its individual characters.
    """

    def __init__(self, scopes: str | Sequence[str] | None = None) -> None:
        if scopes is None:
            scopes = []
        elif isinstance(scopes, str):
            # A str is itself a Sequence[str]; list("admin") would yield
            # ["a", "d", "m", "i", "n"] instead of ["admin"].
            scopes = [scopes]
        # Always store a fresh list so later mutation of the caller's
        # sequence does not leak into these credentials.
        self.scopes = list(scopes)
-
-
class BaseUser:
    """Base class for user objects.

    Each property below raises ``NotImplementedError`` here; concrete
    subclasses are expected to override the ones they support.
    """

    @property
    def is_authenticated(self) -> bool:
        """Whether this user counts as authenticated."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Human-readable name for this user."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identifier string for this user."""
        raise NotImplementedError()  # pragma: no cover
-
-
class SimpleUser(BaseUser):
    """An always-authenticated user known only by a username.

    ``identity`` is not overridden here, so reading it falls through to
    ``BaseUser``, which raises ``NotImplementedError``.
    """

    def __init__(self, username: str) -> None:
        # Kept as a public attribute; ``display_name`` reports it back.
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True``: holding a username means authenticated here."""
        return True

    @property
    def display_name(self) -> str:
        """The username supplied at construction."""
        return self.username
-
-
class UnauthenticatedUser(BaseUser):
    """Placeholder user for a connection with no authenticated identity.

    ``identity`` is not overridden here, so reading it falls through to
    ``BaseUser``, which raises ``NotImplementedError``.
    """

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""
|