You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

93 lines
2.7 KiB

  1. from typing import Optional
  2. from fastapi.openapi.models import APIKey, APIKeyIn
  3. from fastapi.security.base import SecurityBase
  4. from starlette.exceptions import HTTPException
  5. from starlette.requests import Request
  6. from starlette.status import HTTP_403_FORBIDDEN
  class APIKeyBase(SecurityBase):
      """Common parent of the API-key security schemes defined in this module.

      Adds no behavior of its own on top of ``SecurityBase``.
      """
  class APIKeyQuery(APIKeyBase):
      """Security scheme that reads an API key from a URL query parameter.

      Instances are awaitable callables taking the incoming request and
      returning the value of the query parameter called ``name``.
      """

      def __init__(
          self,
          *,
          name: str,
          scheme_name: Optional[str] = None,
          description: Optional[str] = None,
          auto_error: bool = True
      ):
          """Configure the scheme.

          Args:
              name: Name of the query parameter that carries the API key.
              scheme_name: Name stored on ``self.scheme_name``; falls back to
                  the class name when omitted or empty.
              description: Description passed through to the ``APIKey`` model.
              auto_error: When true, a missing key raises ``HTTPException``;
                  when false, ``__call__`` returns ``None`` instead.
          """
          # "in" is a Python keyword, so it has to be supplied via ** unpacking.
          self.model: APIKey = APIKey(
              **{"in": APIKeyIn.query}, name=name, description=description
          )
          self.scheme_name = scheme_name or self.__class__.__name__
          self.auto_error = auto_error

      async def __call__(self, request: Request) -> Optional[str]:
          """Return the API key found in the request's query string.

          Args:
              request: The incoming request.

          Returns:
              The query parameter value, or ``None`` when it is missing or
              empty and ``auto_error`` is false.

          Raises:
              HTTPException: With status 403 and detail "Not authenticated"
                  when the key is missing or empty and ``auto_error`` is true.
          """
          # .get() yields None for an absent parameter, so the local is
          # Optional[str], not str.
          api_key: Optional[str] = request.query_params.get(self.model.name)
          if api_key:
              return api_key
          # An empty string is treated the same as a missing parameter.
          if self.auto_error:
              raise HTTPException(
                  status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
              )
          return None
  class APIKeyHeader(APIKeyBase):
      """Security scheme that reads an API key from an HTTP request header.

      Instances are awaitable callables taking the incoming request and
      returning the value of the header called ``name``.
      """

      def __init__(
          self,
          *,
          name: str,
          scheme_name: Optional[str] = None,
          description: Optional[str] = None,
          auto_error: bool = True
      ):
          """Configure the scheme.

          Args:
              name: Name of the header that carries the API key.
              scheme_name: Name stored on ``self.scheme_name``; falls back to
                  the class name when omitted or empty.
              description: Description passed through to the ``APIKey`` model.
              auto_error: When true, a missing key raises ``HTTPException``;
                  when false, ``__call__`` returns ``None`` instead.
          """
          # "in" is a Python keyword, so it has to be supplied via ** unpacking.
          self.model: APIKey = APIKey(
              **{"in": APIKeyIn.header}, name=name, description=description
          )
          self.scheme_name = scheme_name or self.__class__.__name__
          self.auto_error = auto_error

      async def __call__(self, request: Request) -> Optional[str]:
          """Return the API key found in the request headers.

          Args:
              request: The incoming request.

          Returns:
              The header value, or ``None`` when it is missing or empty and
              ``auto_error`` is false.

          Raises:
              HTTPException: With status 403 and detail "Not authenticated"
                  when the key is missing or empty and ``auto_error`` is true.
          """
          # .get() yields None for an absent header, so the local is
          # Optional[str], not str.
          api_key: Optional[str] = request.headers.get(self.model.name)
          if api_key:
              return api_key
          # An empty string is treated the same as a missing header.
          if self.auto_error:
              raise HTTPException(
                  status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
              )
          return None
  class APIKeyCookie(APIKeyBase):
      """Security scheme that reads an API key from a request cookie.

      Instances are awaitable callables taking the incoming request and
      returning the value of the cookie called ``name``.
      """

      def __init__(
          self,
          *,
          name: str,
          scheme_name: Optional[str] = None,
          description: Optional[str] = None,
          auto_error: bool = True
      ):
          """Configure the scheme.

          Args:
              name: Name of the cookie that carries the API key.
              scheme_name: Name stored on ``self.scheme_name``; falls back to
                  the class name when omitted or empty.
              description: Description passed through to the ``APIKey`` model.
              auto_error: When true, a missing key raises ``HTTPException``;
                  when false, ``__call__`` returns ``None`` instead.
          """
          # "in" is a Python keyword, so all fields travel through one dict.
          model_fields = {
              "in": APIKeyIn.cookie,
              "name": name,
              "description": description,
          }
          self.model: APIKey = APIKey(**model_fields)
          self.auto_error = auto_error
          if scheme_name:
              self.scheme_name = scheme_name
          else:
              self.scheme_name = self.__class__.__name__

      async def __call__(self, request: Request) -> Optional[str]:
          """Return the API key found in the request cookies.

          Args:
              request: The incoming request.

          Returns:
              The cookie value, or ``None`` when it is missing or empty and
              ``auto_error`` is false.

          Raises:
              HTTPException: With status 403 and detail "Not authenticated"
                  when the key is missing or empty and ``auto_error`` is true.
          """
          cookie_value = request.cookies.get(self.model.name)
          if cookie_value:
              return cookie_value
          # An empty string is treated the same as a missing cookie.
          if not self.auto_error:
              return None
          raise HTTPException(
              status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
          )