You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

172 lines
5.4 KiB

  1. from datetime import datetime, timedelta
  2. from typing import Optional
  3. import jwt
  4. from fastapi import Depends, Header
  5. from jwt import PyJWTError
  6. from starlette.exceptions import HTTPException
  7. from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
  8. from fastapi.security import OAuth2PasswordRequestForm, OAuth2
  9. from crud.user import get_user
  10. from models.token import TokenPayload
  11. from models.user import User
  12. from starlette.requests import Request
  13. from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
  14. from .config import JWT_TOKEN_PREFIX, SECRET_KEY
  15. ALGORITHM = "HS256"
  16. access_token_jwt_subject = "access"
  17. def _get_authorization_token(authorization: str = Header(...)):
  18. token_prefix, token = authorization.split(" ")
  19. if token_prefix != JWT_TOKEN_PREFIX:
  20. raise HTTPException(
  21. status_code=HTTP_403_FORBIDDEN, detail="Invalid authorization type"
  22. )
  23. return token
  24. # async def _get_current_user(
  25. # db: DataBase = Depends(get_database), token: str = Depends(_get_authorization_token)
  26. # ) -> User:
  27. # try:
  28. # payload = jwt.decode(token, str(SECRET_KEY), algorithms=[ALGORITHM])
  29. # token_data = TokenPayload(**payload)
  30. # except PyJWTError:
  31. # raise HTTPException(
  32. # status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
  33. # )
  34. # async with db.pool.acquire() as conn:
  35. # dbuser = await get_user(conn, token_data.username)
  36. # if not dbuser:
  37. # raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
  38. # user = User(**dbuser.dict(), token=token)
  39. # return user
  40. def _get_authorization_token_optional(authorization: str = Header(None)):
  41. if authorization:
  42. return _get_authorization_token(authorization)
  43. return ""
  44. # async def _get_current_user_optional(
  45. # db: DataBase = Depends(get_database),
  46. # token: str = Depends(_get_authorization_token_optional),
  47. # ) -> Optional[User]:
  48. # if token:
  49. # return await _get_current_user(db, token)
  50. # return None
  51. def get_current_user_authorizer(*, required: bool = True):
  52. if required:
  53. return _get_current_user
  54. else:
  55. return _get_current_user_optional
  56. def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None):
  57. to_encode = data.copy()
  58. if expires_delta:
  59. expire = datetime.utcnow() + expires_delta
  60. else:
  61. expire = datetime.utcnow() + timedelta(minutes=15)
  62. to_encode.update({"exp": expire, "sub": access_token_jwt_subject})
  63. encoded_jwt = jwt.encode(to_encode, str(SECRET_KEY), algorithm=ALGORITHM)
  64. return encoded_jwt
  65. ##########################
  66. class OAuth2PasswordBearerCookie(OAuth2):
  67. def __init__(
  68. self,
  69. tokenUrl: str,
  70. scheme_name: str = None,
  71. scopes: dict = None,
  72. auto_error: bool = True,
  73. ):
  74. if not scopes:
  75. scopes = {}
  76. flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
  77. super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
  78. async def __call__(self, request: Request) -> Optional[str]:
  79. header_authorization: str = request.headers.get("Authorization")
  80. cookie_authorization: str = request.cookies.get("Authorization")
  81. header_scheme, header_param = get_authorization_scheme_param(
  82. header_authorization
  83. )
  84. cookie_scheme, cookie_param = get_authorization_scheme_param(
  85. cookie_authorization
  86. )
  87. if header_scheme.lower() == "bearer":
  88. authorization = True
  89. scheme = header_scheme
  90. param = header_param
  91. elif cookie_scheme.lower() == "bearer":
  92. authorization = True
  93. scheme = cookie_scheme
  94. param = cookie_param
  95. else:
  96. authorization = False
  97. if not authorization or scheme.lower() != "bearer":
  98. if self.auto_error:
  99. raise HTTPException(
  100. status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
  101. )
  102. else:
  103. return None
  104. return param
  105. oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token")
  106. def create_access_token(*, data: dict, expires_delta: timedelta = None):
  107. to_encode = data.copy()
  108. if expires_delta:
  109. expire = datetime.utcnow() + expires_delta
  110. else:
  111. expire = datetime.utcnow() + timedelta(minutes=15)
  112. to_encode.update({"exp": expire})
  113. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  114. return encoded_jwt
  115. async def get_current_user(token: str = Depends(oauth2_scheme)):
  116. credentials_exception = HTTPException(
  117. status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
  118. )
  119. try:
  120. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  121. username: str = payload.get("sub")
  122. if username is None:
  123. raise credentials_exception
  124. token_data = TokenData(username=username)
  125. except PyJWTError:
  126. raise credentials_exception
  127. user = get_user(fake_users_db, username=token_data.username)
  128. if user is None:
  129. raise credentials_exception
  130. return user
  131. async def get_current_active_user(current_user: User = Depends(get_current_user)):
  132. if current_user.disabled:
  133. raise HTTPException(status_code=400, detail="Inactive user")
  134. return current_user