|
- from typing import Any, Dict, List, Optional, Union
-
- from fastapi.exceptions import HTTPException
- from fastapi.openapi.models import OAuth2 as OAuth2Model
- from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
- from fastapi.param_functions import Form
- from fastapi.security.base import SecurityBase
- from fastapi.security.utils import get_authorization_scheme_param
- from starlette.requests import Request
- from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
-
-
class OAuth2PasswordRequestForm:
    """
    This is a dependency class, use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestForm = Depends()):
            print(form_data.username)
            print(form_data.password)
            for scope in form_data.scopes:
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return {"username": form_data.username}


    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        Nevertheless, this dependency class is permissive and allows not passing it. If you want to enforce it,
        use instead the OAuth2PasswordRequestFormStrict dependency. When it IS sent, it must be exactly
        "password" (the pattern is anchored at both ends).
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret

    Attributes set on the instance: grant_type, username, password, client_id,
    client_secret (as received) and scopes (the `scope` field split on whitespace
    into a list of strings).
    """

    def __init__(
        self,
        # Anchored pattern: an unanchored "password" would also accept values
        # such as "passwordXYZ", because the regex is only matched from the start.
        grant_type: Optional[str] = Form(None, regex="^password$"),
        username: str = Form(...),
        password: str = Form(...),
        scope: str = Form(""),
        client_id: Optional[str] = Form(None),
        client_secret: Optional[str] = Form(None),
    ):
        self.grant_type = grant_type
        self.username = username
        self.password = password
        # The form carries scopes as one space-separated string; expose a list.
        self.scopes = scope.split()
        self.client_id = client_id
        self.client_secret = client_secret
-
-
class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
    """
    This is a dependency class, use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
            print(form_data.username)
            print(form_data.password)
            for scope in form_data.scopes:
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return {"username": form_data.username}


    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        This dependency is strict about it: the field is required and must be exactly "password".
        If you want to be permissive, use instead the OAuth2PasswordRequestForm dependency class.
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    """

    def __init__(
        self,
        # Required (...) and anchored so only the exact value "password" passes;
        # an unanchored pattern would also accept e.g. "passwordXYZ".
        grant_type: str = Form(..., regex="^password$"),
        username: str = Form(...),
        password: str = Form(...),
        scope: str = Form(""),
        client_id: Optional[str] = Form(None),
        client_secret: Optional[str] = Form(None),
    ):
        # All attribute assignment (including splitting `scope`) is done by the
        # permissive base class; this subclass only tightens the form schema.
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
-
-
class OAuth2(SecurityBase):
    """
    Generic OAuth2 security scheme.

    As a dependency it returns the raw value of the "Authorization" header,
    without parsing the scheme out of it. The bearer subclasses in this module
    override __call__ to extract the token.

    Args:
        flows: OAuth2 flows description (a model or an equivalent dict).
            Defaults to an empty OAuthFlowsModel built per instance.
        scheme_name: name of the security scheme; defaults to the class name.
        description: optional description stored on the scheme model.
        auto_error: if True, a missing/empty Authorization header raises
            HTTPException(403, "Not authenticated"); if False, the dependency
            returns None instead.
    """

    def __init__(
        self,
        *,
        flows: Optional[Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]]] = None,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        if flows is None:
            # Build a fresh model per instance rather than sharing one default
            # object that was created once at import time.
            flows = OAuthFlowsModel()
        self.model = OAuth2Model(flows=flows, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(self, request: Request) -> Optional[str]:
        # .get returns None when the header is absent.
        authorization: Optional[str] = request.headers.get("Authorization")
        if not authorization:
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        return authorization
-
-
class OAuth2PasswordBearer(OAuth2):
    """
    OAuth2 "password" flow scheme that extracts a Bearer token.

    Args:
        tokenUrl: URL where clients obtain a token; stored in the scheme's
            password flow.
        scheme_name: name of the security scheme; defaults to the class name.
        scopes: mapping of scope name to description; an empty dict if omitted.
        description: optional description of the scheme.
        auto_error: if True, a missing or non-Bearer Authorization header
            raises HTTPException(401) with a "WWW-Authenticate: Bearer" header;
            if False, the dependency returns None instead.

    As a dependency it returns the token part of "Authorization: Bearer <token>".
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        if not scopes:
            scopes = {}
        flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
        super().__init__(
            flows=flows,
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        # .get returns None when the header is absent.
        authorization: Optional[str] = request.headers.get("Authorization")
        # Only parse a header that is actually present; the scheme comparison
        # is case-insensitive.
        if authorization:
            scheme, param = get_authorization_scheme_param(authorization)
            if scheme.lower() == "bearer":
                return param
        if self.auto_error:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail="Not authenticated",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return None
-
-
class OAuth2AuthorizationCodeBearer(OAuth2):
    """
    OAuth2 "authorization code" flow scheme that extracts a Bearer token.

    Args:
        authorizationUrl: authorization endpoint URL stored in the flow.
        tokenUrl: token endpoint URL stored in the flow.
        refreshUrl: optional refresh endpoint URL; stored as None if omitted.
        scheme_name: name of the security scheme; defaults to the class name.
        scopes: mapping of scope name to description; an empty dict if omitted.
        description: optional description of the scheme.
        auto_error: if True, a missing or non-Bearer Authorization header
            raises HTTPException(401) with a "WWW-Authenticate: Bearer" header;
            if False, the dependency returns None instead.

    As a dependency it returns the token part of "Authorization: Bearer <token>".
    """

    def __init__(
        self,
        authorizationUrl: str,
        tokenUrl: str,
        refreshUrl: Optional[str] = None,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        if not scopes:
            scopes = {}
        flows = OAuthFlowsModel(
            authorizationCode={
                "authorizationUrl": authorizationUrl,
                "tokenUrl": tokenUrl,
                "refreshUrl": refreshUrl,
                "scopes": scopes,
            }
        )
        super().__init__(
            flows=flows,
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        # .get returns None when the header is absent.
        authorization: Optional[str] = request.headers.get("Authorization")
        # Only parse a header that is actually present; the scheme comparison
        # is case-insensitive.
        if authorization:
            scheme, param = get_authorization_scheme_param(authorization)
            if scheme.lower() == "bearer":
                return param
        if self.auto_error:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail="Not authenticated",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return None  # pragma: nocover
-
-
class SecurityScopes:
    """
    Holder for a list of OAuth2 scope strings.

    Attributes:
        scopes: the scopes passed in, or a new empty list when none (or an
            empty/falsy value) were given.
        scope_str: the same scopes joined with single spaces, e.g. "me items".
    """

    def __init__(self, scopes: Optional[List[str]] = None):
        collected = scopes if scopes else []
        self.scopes = collected
        self.scope_str = " ".join(collected)
|