|
- from datetime import datetime, timedelta
- from typing import Optional
-
- import jwt
- from fastapi import Depends, Header
- from jwt import PyJWTError
- from starlette.exceptions import HTTPException
- from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
- from fastapi.security import OAuth2PasswordRequestForm, OAuth2
- from crud.user import get_user
- from models.token import TokenPayload
- from models.user import User
- from starlette.requests import Request
- from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
- from .config import JWT_TOKEN_PREFIX, SECRET_KEY
-
# HMAC-SHA256: the algorithm passed to every jwt.encode / jwt.decode call below.
ALGORITHM: str = "HS256"
# Value written into the JWT "sub" claim by the create_access_token variant that sets it.
access_token_jwt_subject: str = "access"
-
-
def _get_authorization_token(authorization: str = Header(...)):
    """Extract the token from an ``Authorization: <prefix> <token>`` header.

    Args:
        authorization: Raw ``Authorization`` header value (required by FastAPI
            via ``Header(...)``).

    Returns:
        The text following ``JWT_TOKEN_PREFIX`` and a single space.

    Raises:
        HTTPException: 403 ("Invalid authorization type") when the header is
            not exactly two space-separated parts, or when the first part is
            not ``JWT_TOKEN_PREFIX``.
    """
    # Check the shape before unpacking. Headers such as "Token" or "Token a b"
    # used to raise ValueError from tuple unpacking, which escaped as an
    # unhandled error instead of an authentication failure.
    parts = authorization.split(" ")
    if len(parts) != 2 or parts[0] != JWT_TOKEN_PREFIX:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Invalid authorization type"
        )

    return parts[1]
-
-
- # async def _get_current_user(
- # db: DataBase = Depends(get_database), token: str = Depends(_get_authorization_token)
- # ) -> User:
- # try:
- # payload = jwt.decode(token, str(SECRET_KEY), algorithms=[ALGORITHM])
- # token_data = TokenPayload(**payload)
- # except PyJWTError:
- # raise HTTPException(
- # status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
- # )
-
- # async with db.pool.acquire() as conn:
- # dbuser = await get_user(conn, token_data.username)
- # if not dbuser:
- # raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
-
- # user = User(**dbuser.dict(), token=token)
- # return user
-
-
def _get_authorization_token_optional(authorization: str = Header(None)):
    """Like ``_get_authorization_token``, but tolerate a missing header.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None``/empty
            when the client sent none.

    Returns:
        The extracted token, or ``""`` when no header value was supplied.
        A malformed header is still rejected by ``_get_authorization_token``.
    """
    if not authorization:
        # No credentials supplied: signal "anonymous" with an empty string.
        return ""
    return _get_authorization_token(authorization)
-
-
- # async def _get_current_user_optional(
- # db: DataBase = Depends(get_database),
- # token: str = Depends(_get_authorization_token_optional),
- # ) -> Optional[User]:
- # if token:
- # return await _get_current_user(db, token)
-
- # return None
-
-
def get_current_user_authorizer(*, required: bool = True):
    """Choose the current-user dependency for a route.

    Args:
        required: When truthy, return the dependency for mandatory
            authentication; otherwise return the optional variant.

    Returns:
        ``_get_current_user`` or ``_get_current_user_optional``.

    NOTE(review): the only visible definitions of ``_get_current_user`` and
    ``_get_current_user_optional`` in this module are commented out, so
    calling this presumably raises NameError unless they are defined
    elsewhere. Confirm, or point this at ``get_current_user``.
    """
    return _get_current_user if required else _get_current_user_optional
-
-
def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None):
    """Sign ``data`` as a JWT carrying an expiry and the access-token subject.

    NOTE(review): a second ``create_access_token`` is defined later in this
    module. While it exists, it replaces this definition at import time.

    Args:
        data: Claims to embed. The dict is copied, so the caller's dict is not
            mutated. Any "exp" or "sub" key in it is overwritten.
        expires_delta: Token lifetime measured from now (UTC). ``None`` means
            the 15-minute default. An explicit ``timedelta(0)`` is honoured.

    Returns:
        The value returned by ``jwt.encode`` (a ``str`` under PyJWT 2.x).
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy, so a truthiness
    # test would silently turn an explicit zero lifetime into 15 minutes.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    expire = datetime.utcnow() + lifetime
    to_encode.update({"exp": expire, "sub": access_token_jwt_subject})
    return jwt.encode(to_encode, str(SECRET_KEY), algorithm=ALGORITHM)
-
-
-
-
-
- ##########################
class OAuth2PasswordBearerCookie(OAuth2):
    """OAuth2 password-flow security scheme that reads a bearer token from a
    request header or, failing that, from a cookie.

    The ``Authorization`` request header is checked first. If it does not
    carry a ``Bearer`` credential, the ``Authorization`` cookie is tried.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[dict] = None,
        auto_error: bool = True,
    ):
        """Register the password flow for OpenAPI.

        Args:
            tokenUrl: URL of the token endpoint advertised in the OpenAPI schema.
            scheme_name: Optional OpenAPI name for this security scheme.
            scopes: Mapping of scope name to description. ``None`` or empty
                means no scopes.
            auto_error: When true, a missing or non-bearer credential raises
                403. Otherwise ``__call__`` returns ``None``.
        """
        # Build a fresh dict per instance instead of sharing a mutable default.
        if not scopes:
            scopes = {}
        flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
        super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)

    async def __call__(self, request: Request) -> Optional[str]:
        """Return the bearer token from the header or cookie.

        Returns:
            The credential after the ``Bearer`` scheme, or ``None`` when none
            is found and ``auto_error`` is false.

        Raises:
            HTTPException: 403 ("Not authenticated") when no bearer credential
                is found and ``auto_error`` is true.
        """
        # The module header never imported this helper, so every call used to
        # fail with NameError. A local import keeps the fix self-contained.
        from fastapi.security.utils import get_authorization_scheme_param

        # Precedence: header first, then cookie.
        candidates = (
            request.headers.get("Authorization"),
            request.cookies.get("Authorization"),
        )
        for raw_value in candidates:
            # get_authorization_scheme_param returns ("", "") for a missing value.
            scheme, param = get_authorization_scheme_param(raw_value)
            if scheme.lower() == "bearer":
                return param

        if self.auto_error:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
            )
        return None
-
-
# Shared security-scheme instance that get_current_user depends on; it
# advertises "/token" as the password-flow token endpoint.
oauth2_scheme: OAuth2PasswordBearerCookie = OAuth2PasswordBearerCookie(tokenUrl="/token")
-
-
-
def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None):
    """Sign ``data`` as a JWT with an expiry claim.

    Unlike an earlier same-named definition in this module, this one does not
    set "sub". ``get_current_user`` reads the username from "sub", so callers
    presumably pass it in ``data``. Confirm at the call sites.

    Args:
        data: Claims to embed. The dict is copied, so the caller's dict is not
            mutated. Any "exp" key in it is overwritten.
        expires_delta: Token lifetime measured from now (UTC). ``None`` means
            the 15-minute default. An explicit ``timedelta(0)`` is honoured.

    Returns:
        The value returned by ``jwt.encode`` (a ``str`` under PyJWT 2.x).
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy, so a truthiness
    # test would silently turn an explicit zero lifetime into 15 minutes.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    to_encode.update({"exp": datetime.utcnow() + lifetime})
    # NOTE(review): other code in this module signs with str(SECRET_KEY).
    # Confirm SECRET_KEY's type; encode and decode must use the same key form.
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
-
-
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the user identified by a bearer JWT.

    Args:
        token: Raw JWT extracted by ``oauth2_scheme`` (header or cookie).

    Returns:
        Whatever ``get_user`` returns for the username in the token's "sub"
        claim.

    Raises:
        HTTPException: 403 ("Could not validate credentials") when the token
            fails decoding or verification, has no "sub" claim, or no matching
            user is found.
    """
    credentials_exception = HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )
    # Keep the try minimal: only decoding and verification raise PyJWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except PyJWTError as exc:
        raise credentials_exception from exc

    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Use the username directly. The previous code wrapped it in `TokenData`,
    # which is never imported or defined, so every valid token raised NameError.
    # NOTE(review): `fake_users_db` is not defined in this module, and
    # `get_user` comes from crud.user. Confirm the intended data source and
    # whether get_user must be awaited.
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user
-
-
async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the authenticated user unless the account is disabled.

    Args:
        current_user: User resolved by ``get_current_user``.

    Raises:
        HTTPException: 400 ("Inactive user") when ``current_user.disabled`` is
            truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|