from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, Header
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.security import OAuth2, OAuth2PasswordRequestForm
from fastapi.security.utils import get_authorization_scheme_param
from jwt import PyJWTError
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND

from crud.user import get_user
from models.token import TokenPayload
from models.user import User

from .config import JWT_TOKEN_PREFIX, SECRET_KEY

ALGORITHM = "HS256"
# Kept as a module-level name for any external importers; no live code in
# this module reads it.
access_token_jwt_subject = "access"


def _get_authorization_token(authorization: str = Header(...)):
    """Extract the token from an ``Authorization`` header value.

    The header must consist of exactly two space-separated parts, the first
    of which equals ``JWT_TOKEN_PREFIX``.

    Args:
        authorization: Raw ``Authorization`` header value.

    Returns:
        The token part of the header.

    Raises:
        HTTPException: 403 when the header is malformed or carries a
            different prefix.
    """
    parts = authorization.split(" ")
    # A header without exactly "<prefix> <token>" used to fail tuple
    # unpacking with ValueError (HTTP 500); reject it like a wrong prefix.
    if len(parts) != 2 or parts[0] != JWT_TOKEN_PREFIX:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Invalid authorization type"
        )
    return parts[1]


# NOTE(review): the header-based user lookup below is disabled; it needs
# `DataBase` and `get_database`, which are not imported in this module.
# async def _get_current_user(
#     db: DataBase = Depends(get_database), token: str = Depends(_get_authorization_token)
# ) -> User:
#     try:
#         payload = jwt.decode(token, str(SECRET_KEY), algorithms=[ALGORITHM])
#         token_data = TokenPayload(**payload)
#     except PyJWTError:
#         raise HTTPException(
#             status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
#         )
#     async with db.pool.acquire() as conn:
#         dbuser = await get_user(conn, token_data.username)
#     if not dbuser:
#         raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
#     user = User(**dbuser.dict(), token=token)
#     return user


def _get_authorization_token_optional(authorization: str = Header(None)):
    """Return the token from the ``Authorization`` header, or ``""`` if absent.

    A header that is present is validated by ``_get_authorization_token`` and
    may therefore raise the same 403 ``HTTPException``.
    """
    if authorization:
        return _get_authorization_token(authorization)
    return ""


# async def _get_current_user_optional(
#     db: DataBase = Depends(get_database),
#     token: str = Depends(_get_authorization_token_optional),
# ) -> Optional[User]:
#     if token:
#         return await _get_current_user(db, token)
#     return None


def get_current_user_authorizer(*, required: bool = True):
    """Select the user-resolving dependency for a route.

    Args:
        required: When true, return the dependency that demands a user;
            otherwise return the optional variant.

    NOTE(review): both returned names exist only as commented-out code in
    this module, so calling this function currently raises ``NameError``.
    Restore those helpers (or repoint this function) before using it.
    """
    if required:
        return _get_current_user
    else:
        return _get_current_user_optional


##########################
# Cookie-or-header OAuth2 bearer scheme
##########################


class OAuth2PasswordBearerCookie(OAuth2):
    """OAuth2 password flow that accepts the bearer token from a header or cookie.

    The ``Authorization`` request header is checked first, then the
    ``Authorization`` cookie; each is expected in ``"Bearer <token>"`` form.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[dict] = None,
        auto_error: bool = True,
    ):
        """Configure the password flow.

        Args:
            tokenUrl: URL of the token-issuing endpoint.
            scheme_name: Optional name forwarded to ``OAuth2``.
            scopes: Optional scope mapping; an empty dict is used when falsy.
            auto_error: When true, ``__call__`` raises on missing credentials
                instead of returning ``None``.
        """
        if not scopes:
            scopes = {}
        flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
        super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)

    async def __call__(self, request: Request) -> Optional[str]:
        """Return the bearer token carried by ``request``.

        Returns:
            The token from the header if it uses the bearer scheme, else the
            token from the cookie if that uses the bearer scheme, else
            ``None`` (only when ``auto_error`` is false).

        Raises:
            HTTPException: 403 when no bearer credentials are found and
                ``auto_error`` is true.
        """
        header_scheme, header_param = get_authorization_scheme_param(
            request.headers.get("Authorization")
        )
        cookie_scheme, cookie_param = get_authorization_scheme_param(
            request.cookies.get("Authorization")
        )

        # The header takes precedence over the cookie.
        if header_scheme.lower() == "bearer":
            return header_param
        if cookie_scheme.lower() == "bearer":
            return cookie_param

        if self.auto_error:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        return None


oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token")


def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` as a signed JWT with an expiry claim.

    ``data`` is copied, not mutated. No ``"sub"`` claim is added here;
    ``get_current_user`` reads ``"sub"`` as the username, so callers must
    supply it in ``data``.

    Args:
        data: Claims to embed in the token.
        expires_delta: Token lifetime; defaults to 15 minutes when falsy.

    Returns:
        The encoded JWT as returned by ``jwt.encode``.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    to_encode = data.copy()
    to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})
    # str() matches how the rest of this module treats SECRET_KEY and is a
    # no-op when the key is already a plain string.
    return jwt.encode(to_encode, str(SECRET_KEY), algorithm=ALGORITHM)


async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the user identified by the ``"sub"`` claim of ``token``.

    Raises:
        HTTPException: 403 when the token cannot be decoded, has no ``"sub"``
            claim, or no matching user is found.

    NOTE(review): ``TokenData`` and ``fake_users_db`` are not defined or
    imported in the visible part of this module — confirm where they come
    from; as written the lookup raises ``NameError``.
    """
    credentials_exception = HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )
    try:
        payload = jwt.decode(token, str(SECRET_KEY), algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except PyJWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user


async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return ``current_user`` unless it is flagged as disabled.

    Raises:
        HTTPException: 400 when ``current_user.disabled`` is truthy.
    """
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user