You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

166 lines
5.8 KiB

  1. import binascii
  2. from base64 import b64decode
  3. from typing import Optional
  4. from fastapi.exceptions import HTTPException
  5. from fastapi.openapi.models import HTTPBase as HTTPBaseModel
  6. from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
  7. from fastapi.security.base import SecurityBase
  8. from fastapi.security.utils import get_authorization_scheme_param
  9. from pydantic import BaseModel
  10. from starlette.requests import Request
  11. from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
  12. class HTTPBasicCredentials(BaseModel):
  13. username: str
  14. password: str
  15. class HTTPAuthorizationCredentials(BaseModel):
  16. scheme: str
  17. credentials: str
  18. class HTTPBase(SecurityBase):
  19. def __init__(
  20. self,
  21. *,
  22. scheme: str,
  23. scheme_name: Optional[str] = None,
  24. description: Optional[str] = None,
  25. auto_error: bool = True,
  26. ):
  27. self.model = HTTPBaseModel(scheme=scheme, description=description)
  28. self.scheme_name = scheme_name or self.__class__.__name__
  29. self.auto_error = auto_error
  30. async def __call__(
  31. self, request: Request
  32. ) -> Optional[HTTPAuthorizationCredentials]:
  33. authorization: str = request.headers.get("Authorization")
  34. scheme, credentials = get_authorization_scheme_param(authorization)
  35. if not (authorization and scheme and credentials):
  36. if self.auto_error:
  37. raise HTTPException(
  38. status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
  39. )
  40. else:
  41. return None
  42. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  43. class HTTPBasic(HTTPBase):
  44. def __init__(
  45. self,
  46. *,
  47. scheme_name: Optional[str] = None,
  48. realm: Optional[str] = None,
  49. description: Optional[str] = None,
  50. auto_error: bool = True,
  51. ):
  52. self.model = HTTPBaseModel(scheme="basic", description=description)
  53. self.scheme_name = scheme_name or self.__class__.__name__
  54. self.realm = realm
  55. self.auto_error = auto_error
  56. async def __call__( # type: ignore
  57. self, request: Request
  58. ) -> Optional[HTTPBasicCredentials]:
  59. authorization: str = request.headers.get("Authorization")
  60. scheme, param = get_authorization_scheme_param(authorization)
  61. if self.realm:
  62. unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
  63. else:
  64. unauthorized_headers = {"WWW-Authenticate": "Basic"}
  65. invalid_user_credentials_exc = HTTPException(
  66. status_code=HTTP_401_UNAUTHORIZED,
  67. detail="Invalid authentication credentials",
  68. headers=unauthorized_headers,
  69. )
  70. if not authorization or scheme.lower() != "basic":
  71. if self.auto_error:
  72. raise HTTPException(
  73. status_code=HTTP_401_UNAUTHORIZED,
  74. detail="Not authenticated",
  75. headers=unauthorized_headers,
  76. )
  77. else:
  78. return None
  79. try:
  80. data = b64decode(param).decode("ascii")
  81. except (ValueError, UnicodeDecodeError, binascii.Error):
  82. raise invalid_user_credentials_exc
  83. username, separator, password = data.partition(":")
  84. if not separator:
  85. raise invalid_user_credentials_exc
  86. return HTTPBasicCredentials(username=username, password=password)
  87. class HTTPBearer(HTTPBase):
  88. def __init__(
  89. self,
  90. *,
  91. bearerFormat: Optional[str] = None,
  92. scheme_name: Optional[str] = None,
  93. description: Optional[str] = None,
  94. auto_error: bool = True,
  95. ):
  96. self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
  97. self.scheme_name = scheme_name or self.__class__.__name__
  98. self.auto_error = auto_error
  99. async def __call__(
  100. self, request: Request
  101. ) -> Optional[HTTPAuthorizationCredentials]:
  102. authorization: str = request.headers.get("Authorization")
  103. scheme, credentials = get_authorization_scheme_param(authorization)
  104. if not (authorization and scheme and credentials):
  105. if self.auto_error:
  106. raise HTTPException(
  107. status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
  108. )
  109. else:
  110. return None
  111. if scheme.lower() != "bearer":
  112. if self.auto_error:
  113. raise HTTPException(
  114. status_code=HTTP_403_FORBIDDEN,
  115. detail="Invalid authentication credentials",
  116. )
  117. else:
  118. return None
  119. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  120. class HTTPDigest(HTTPBase):
  121. def __init__(
  122. self,
  123. *,
  124. scheme_name: Optional[str] = None,
  125. description: Optional[str] = None,
  126. auto_error: bool = True,
  127. ):
  128. self.model = HTTPBaseModel(scheme="digest", description=description)
  129. self.scheme_name = scheme_name or self.__class__.__name__
  130. self.auto_error = auto_error
  131. async def __call__(
  132. self, request: Request
  133. ) -> Optional[HTTPAuthorizationCredentials]:
  134. authorization: str = request.headers.get("Authorization")
  135. scheme, credentials = get_authorization_scheme_param(authorization)
  136. if not (authorization and scheme and credentials):
  137. if self.auto_error:
  138. raise HTTPException(
  139. status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
  140. )
  141. else:
  142. return None
  143. if scheme.lower() != "digest":
  144. raise HTTPException(
  145. status_code=HTTP_403_FORBIDDEN,
  146. detail="Invalid authentication credentials",
  147. )
  148. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)