|
- import asyncio
- import functools
- import inspect
- import typing
-
- from starlette.exceptions import HTTPException
- from starlette.requests import HTTPConnection, Request
- from starlette.responses import RedirectResponse, Response
- from starlette.websockets import WebSocket
-
-
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """
    Report whether every entry of `scopes` is present in `conn.auth.scopes`.

    An empty `scopes` sequence is trivially satisfied; in that case `conn`
    is not inspected at all.
    """
    return all(required in conn.auth.scopes for required in scopes)
-
-
def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable:
    """
    Build a decorator that runs an endpoint only when the connection holds
    every scope in `scopes`.

    The decorated function must declare a parameter named "request" or
    "websocket". The connection is read from that argument, whether the
    caller passes it positionally or by keyword.

    * "websocket" endpoints: the socket is closed when a scope is missing,
      otherwise the endpoint is awaited.
    * "request" endpoints (sync or async): when a scope is missing, a 303
      redirect to `request.url_for(redirect)` is returned if `redirect` is
      given, otherwise `HTTPException(status_code=status_code)` is raised.

    Args:
        scopes: A single scope name, or a sequence of scope names.
        status_code: Status passed to `HTTPException` on denial.
        redirect: Optional name handed to `request.url_for` on denial.

    Returns:
        A decorator producing a wrapper with the endpoint's metadata
        preserved via `functools.wraps`.

    Raises:
        Exception: At decoration time, when the function has neither a
            "request" nor a "websocket" parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        # Locate the connection parameter once, at decoration time, so the
        # wrappers can pick it out of the positional arguments by index.
        for idx, name in enumerate(inspect.signature(func).parameters):
            if name in ("request", "websocket"):
                type_ = name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        def get_connection(
            args: typing.Tuple[typing.Any, ...],
            kwargs: typing.Dict[str, typing.Any],
        ) -> typing.Any:
            # A keyword argument wins; otherwise use the positional slot, or
            # None when the call supplied fewer positionals than `idx`.
            return kwargs.get(type_, args[idx] if idx < len(args) else None)

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = get_connection(args, kwargs)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        def denial_for(
            args: typing.Tuple[typing.Any, ...],
            kwargs: typing.Dict[str, typing.Any],
        ) -> typing.Optional[Response]:
            # Shared by the sync and async request wrappers: returns None
            # when access is allowed, the redirect response when one is
            # configured, and raises HTTPException otherwise.
            request = get_connection(args, kwargs)
            assert isinstance(request, Request)

            if has_required_scope(request, scopes_list):
                return None
            if redirect is not None:
                return RedirectResponse(
                    url=request.url_for(redirect), status_code=303
                )
            raise HTTPException(status_code=status_code)

        if asyncio.iscoroutinefunction(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                denied = denial_for(args, kwargs)
                if denied is not None:
                    return denied
                return await func(*args, **kwargs)

            return async_wrapper

        # Handle sync request/response functions.
        @functools.wraps(func)
        def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
            denied = denial_for(args, kwargs)
            if denied is not None:
                return denied
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
-
-
class AuthenticationError(Exception):
    """
    Exception subclass that adds no behaviour of its own.

    NOTE(review): presumably raised when authentication fails; no raise
    site is visible in this file, so confirm against callers.
    """
-
-
class AuthenticationBackend:
    """
    Base class for authentication backends.

    The only method, `authenticate`, raises `NotImplementedError` here, so
    a usable backend has to override it.
    """

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """
        Authenticate the given connection.

        The annotated result is either an `(AuthCredentials, BaseUser)`
        pair or `None`; this base implementation always raises.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()  # pragma: no cover
-
-
class AuthCredentials:
    """
    Container for a list of scope names, exposed as the `scopes` attribute.

    Args:
        scopes: Scope names to store. `None` (the default) yields an empty
            list. A single string is stored as one scope, matching how
            `requires` treats a string argument. Any other sequence is
            copied into a new list, so later changes to the caller's
            sequence do not affect this object.
    """

    def __init__(
        self,
        scopes: typing.Optional[typing.Union[str, typing.Sequence[str]]] = None,
    ) -> None:
        if scopes is None:
            self.scopes: typing.List[str] = []
        elif isinstance(scopes, str):
            # A str is itself a sequence of characters; list("admin") would
            # silently grant the scopes "a", "d", "m", "i", "n".
            self.scopes = [scopes]
        else:
            self.scopes = list(scopes)
-
-
class BaseUser:
    """
    Base class for user objects.

    Declares three read-only properties: `is_authenticated`,
    `display_name` and `identity`. Each one raises `NotImplementedError`
    here, so subclasses supply the ones they support.
    """

    @property
    def is_authenticated(self) -> bool:
        """Annotated as a bool; not implemented on the base class."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Annotated as a str; not implemented on the base class."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        """Annotated as a str; not implemented on the base class."""
        raise NotImplementedError()  # pragma: no cover
-
-
class SimpleUser(BaseUser):
    """
    User backed by a single `username` attribute.

    `is_authenticated` is always true and `display_name` mirrors
    `username`. `identity` is not overridden, so it keeps the behaviour
    inherited from `BaseUser`.
    """

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always `True`."""
        return True

    @property
    def display_name(self) -> str:
        """The current value of `username`."""
        return self.username
-
-
class UnauthenticatedUser(BaseUser):
    """
    User object that reports itself as not authenticated.

    `display_name` is the empty string. `identity` is not overridden, so
    it keeps the behaviour inherited from `BaseUser`.
    """

    @property
    def is_authenticated(self) -> bool:
        """Always `False`."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""
|