|
- import binascii
- from base64 import b64decode
- from typing import Optional
-
- from fastapi.exceptions import HTTPException
- from fastapi.openapi.models import HTTPBase as HTTPBaseModel
- from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
- from fastapi.security.base import SecurityBase
- from fastapi.security.utils import get_authorization_scheme_param
- from pydantic import BaseModel
- from starlette.requests import Request
- from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
-
-
class HTTPBasicCredentials(BaseModel):
    """Username/password pair decoded from an HTTP Basic ``Authorization`` header.

    ``HTTPBasic.__call__`` builds this after base64-decoding the header
    parameter and splitting the decoded text at its first ``:``.
    """

    # Decoded text before the first ":".
    username: str
    # Decoded text after the first ":" (may itself contain further ":").
    password: str
-
-
class HTTPAuthorizationCredentials(BaseModel):
    """Scheme and credentials split out of an ``Authorization`` header.

    Both values are exactly what ``get_authorization_scheme_param`` returned;
    the dependencies in this module compare the scheme case-insensitively but
    store it unchanged.
    """

    # Authentication scheme token as returned by the header parser.
    scheme: str
    # Credentials portion of the header, not validated or decoded here.
    credentials: str
-
-
class HTTPBase(SecurityBase):
    """Generic HTTP authentication dependency for an arbitrary scheme.

    Reads the ``Authorization`` header and returns its scheme and credentials
    without checking which scheme the client used.

    Args:
        scheme: Scheme name stored in the ``HTTPBaseModel`` security model.
        scheme_name: Security scheme name; defaults to the class name.
        description: Optional description stored in the security model.
        auto_error: When true, a missing or incomplete header raises
            ``HTTPException(403)``; when false, ``None`` is returned instead.
    """

    def __init__(
        self,
        *,
        scheme: str,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        self.model = HTTPBaseModel(scheme=scheme, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Extract the scheme and credentials from the ``Authorization`` header.

        Returns:
            The parsed credentials, or ``None`` when the header is missing or
            incomplete and ``auto_error`` is false.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing
                or lacks a scheme or credentials and ``auto_error`` is true.
        """
        # The header may be absent, in which case .get() returns None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
-
-
class HTTPBasic(HTTPBase):
    """HTTP Basic authentication dependency.

    Decodes ``Authorization: Basic <base64>`` into a username/password pair.
    Error responses are 401s carrying a ``WWW-Authenticate: Basic`` header
    (with the realm, when one is configured) so clients know to prompt.

    Args:
        scheme_name: Security scheme name; defaults to the class name.
        realm: Optional realm advertised in the ``WWW-Authenticate`` header.
        description: Optional description stored in the security model.
        auto_error: When true, a missing header or a non-Basic scheme raises
            401; when false, ``None`` is returned instead. Malformed Basic
            credentials always raise, regardless of this flag.
    """

    def __init__(
        self,
        *,
        scheme_name: Optional[str] = None,
        realm: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        super().__init__(
            scheme="basic",
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )
        self.realm = realm

    async def __call__(  # type: ignore
        self, request: Request
    ) -> Optional[HTTPBasicCredentials]:
        """Decode Basic credentials from the request's ``Authorization`` header.

        Returns:
            The decoded credentials, or ``None`` when the header is missing or
            uses another scheme and ``auto_error`` is false.

        Raises:
            HTTPException: 401 "Not authenticated" when the header is missing
                or not Basic and ``auto_error`` is true; 401 "Invalid
                authentication credentials" when the payload is not valid
                base64-encoded ASCII or contains no ":".
        """
        # The header may be absent, in which case .get() returns None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        if self.realm:
            # The realm is sent as a quoted-string: backslashes and double
            # quotes inside it must be backslash-escaped (backslash first) or
            # the header becomes malformed.
            quoted_realm = self.realm.replace("\\", "\\\\").replace('"', '\\"')
            unauthorized_headers = {
                "WWW-Authenticate": f'Basic realm="{quoted_realm}"'
            }
        else:
            unauthorized_headers = {"WWW-Authenticate": "Basic"}
        if not authorization or scheme.lower() != "basic":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED,
                    detail="Not authenticated",
                    headers=unauthorized_headers,
                )
            return None
        invalid_user_credentials_exc = HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers=unauthorized_headers,
        )
        try:
            data = b64decode(param).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error):
            # The decoding error is an internal detail; don't chain it.
            raise invalid_user_credentials_exc from None
        username, separator, password = data.partition(":")
        if not separator:
            raise invalid_user_credentials_exc
        return HTTPBasicCredentials(username=username, password=password)
-
-
class HTTPBearer(HTTPBase):
    """Bearer-token authentication dependency.

    Returns the raw token from ``Authorization: Bearer <token>``; the token
    itself is not validated here.

    Args:
        bearerFormat: Optional token-format hint stored in ``HTTPBearerModel``.
        scheme_name: Security scheme name; defaults to the class name.
        description: Optional description stored in the security model.
        auto_error: When true, failures raise ``HTTPException(403)``; when
            false, ``None`` is returned instead.
    """

    def __init__(
        self,
        *,
        bearerFormat: Optional[str] = None,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Extract a bearer token from the request's ``Authorization`` header.

        Returns:
            The scheme and token, or ``None`` on failure when ``auto_error``
            is false.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing
                or incomplete; 403 "Invalid authentication credentials" when
                the scheme is not Bearer. Both only when ``auto_error`` is true.
        """
        # The header may be absent, in which case .get() returns None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        if scheme.lower() != "bearer":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN,
                    detail="Invalid authentication credentials",
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
-
-
class HTTPDigest(HTTPBase):
    """HTTP Digest authentication dependency.

    Only checks that the ``Authorization`` header uses the Digest scheme and
    returns its raw parameters; parsing and verifying the digest response is
    left to the caller.

    Args:
        scheme_name: Security scheme name; defaults to the class name.
        description: Optional description stored in the security model.
        auto_error: When true, failures raise ``HTTPException(403)``; when
            false, ``None`` is returned instead.
    """

    def __init__(
        self,
        *,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        super().__init__(
            scheme="digest",
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Extract Digest credentials from the request's ``Authorization`` header.

        Returns:
            The scheme and unparsed digest parameters, or ``None`` on failure
            when ``auto_error`` is false.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing
                or incomplete; 403 "Invalid authentication credentials" when
                the scheme is not Digest. Both only when ``auto_error`` is true.
        """
        # The header may be absent, in which case .get() returns None.
        authorization: Optional[str] = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        if scheme.lower() != "digest":
            # Respect auto_error here as HTTPBearer does; raising
            # unconditionally made optional Digest auth (auto_error=False)
            # reject requests that use a different scheme.
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN,
                    detail="Invalid authentication credentials",
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|