|
- # security.py
import logging
import os
import ssl
import time
from typing import Dict, Any, List, Optional

import httpx
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError
from jose.exceptions import JOSEError
-
logger = logging.getLogger("security")

# === CONFIG ===
KEYCLOAK_ISSUER = os.getenv(
    "KEYCLOAK_ISSUER",
    "https://192.168.1.3:10002/realms/API.Server.local",
)
KEYCLOAK_JWKS_URL = os.getenv(
    "KEYCLOAK_JWKS_URL",
    "https://192.168.1.3:10002/realms/API.Server.local/protocol/openid-connect/certs",
)
KEYCLOAK_AUDIENCE = os.getenv("KEYCLOAK_AUDIENCE", "Fastapi")

# Signature algorithms accepted when decoding tokens. Only asymmetric (RSA)
# algorithms are listed, so a token cannot switch to an HMAC algorithm and be
# "signed" with the public key.
# NOTE(review): python-jose may not support the PS* family; confirm before
# relying on PSS-signed tokens.
ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

# TLS verification for the JWKS download:
#   KEYCLOAK_CA_BUNDLE=/path/CA.crt -> trust exactly that CA (e.g. a private CA
#                                      in production; a bad path fails at import)
#   KEYCLOAK_TLS_VERIFY=true        -> verify against the system trust store
#   neither set                     -> verification DISABLED (the previous
#                                      default, meant for tests against a
#                                      self-signed certificate)
# With verification disabled, a man-in-the-middle can serve its own JWKS and
# thereby forge tokens this service would accept: never run that way in prod.
_KEYCLOAK_CA_BUNDLE = os.getenv("KEYCLOAK_CA_BUNDLE")
_KEYCLOAK_TLS_VERIFY = os.getenv("KEYCLOAK_TLS_VERIFY", "").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
if _KEYCLOAK_CA_BUNDLE:
    # An SSLContext is accepted by every httpx version (a str path for
    # `verify` is deprecated in recent httpx releases).
    _tls_verify = ssl.create_default_context(cafile=_KEYCLOAK_CA_BUNDLE)
elif _KEYCLOAK_TLS_VERIFY:
    _tls_verify = True
else:
    _tls_verify = False
    logger.warning(
        "TLS certificate verification is DISABLED for the JWKS download (%s); "
        "set KEYCLOAK_CA_BUNDLE or KEYCLOAK_TLS_VERIFY=true in production.",
        KEYCLOAK_JWKS_URL,
    )
_http = httpx.AsyncClient(timeout=5.0, verify=_tls_verify)

# Cached JWKS document; None means "not downloaded yet" (or invalidated).
_cached_jwks: Optional[Dict[str, Any]] = None

# Do NOT name this 'security', to avoid clashing with the module name.
http_bearer = HTTPBearer(auto_error=True)
-
-
async def _get_jwks() -> Dict[str, Any]:
    """Return Keycloak's JSON Web Key Set, downloading it on first use.

    The document is cached in the module-level ``_cached_jwks``. Callers may
    reset that variable to None to force a fresh download on the next call.

    Returns:
        The JWKS document (a dict holding a ``"keys"`` list).

    Raises:
        httpx.HTTPError: the request failed or returned a non-2xx status.
        ValueError: the body is not JSON, or is not an object with a
            ``"keys"`` list. Nothing is cached in that case, so the next call
            retries instead of reusing a broken document forever.
    """
    global _cached_jwks
    if _cached_jwks is None:
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Fetching JWKS from: %s", KEYCLOAK_JWKS_URL)
        resp = await _http.get(KEYCLOAK_JWKS_URL)
        resp.raise_for_status()
        document = resp.json()  # json.JSONDecodeError is a ValueError
        if not isinstance(document, dict) or not isinstance(document.get("keys"), list):
            raise ValueError(f"Malformed JWKS document returned by {KEYCLOAK_JWKS_URL}")
        # Publish to the cache only after validation.
        _cached_jwks = document
    return _cached_jwks
-
-
# Minimum number of seconds between two forced JWKS re-downloads triggered by a
# token whose "kid" is not in the cached key set. Without this throttle, every
# request carrying a random/forged kid would cost one HTTP request to Keycloak.
_JWKS_FORCED_REFRESH_MIN_INTERVAL = 30.0
# time.monotonic() value of the last forced re-download; None = never happened.
_last_forced_jwks_refresh: Optional[float] = None


def _find_jwk(jwks: Dict[str, Any], kid: Optional[str]) -> Optional[Dict[str, Any]]:
    """Return the entry of ``jwks["keys"]`` whose "kid" equals *kid*, or None."""
    return next((key for key in jwks.get("keys", []) if key.get("kid") == kid), None)


async def _get_key(token: str) -> Dict[str, Any]:
    """Return the JWK matching the "kid" in *token*'s header.

    The header is read without verifying the signature; it only selects which
    key to use, and no signature check happens here.

    If the kid is absent from the cached key set, the keys may have been
    rotated. In that case the cache is dropped and the JWKS downloaded again,
    but at most once every ``_JWKS_FORCED_REFRESH_MIN_INTERVAL`` seconds.
    Trade-off: right after a forced refresh, a token signed with a brand-new
    key can be rejected until the interval elapses.

    Args:
        token: the raw compact JWT.

    Returns:
        The matching JWK dict from the JWKS.

    Raises:
        JWTError: the token header cannot be decoded.
        HTTPException: 401 when no key with the token's kid is available.
        Any exception raised by _get_jwks while downloading the key set.
    """
    global _cached_jwks, _last_forced_jwks_refresh

    kid = jwt.get_unverified_header(token).get("kid")
    key = _find_jwk(await _get_jwks(), kid)
    if key is not None:
        return key

    # Key possibly rotated: invalidate the cache and retry, throttled.
    now = time.monotonic()
    if (
        _last_forced_jwks_refresh is None
        or now - _last_forced_jwks_refresh >= _JWKS_FORCED_REFRESH_MIN_INTERVAL
    ):
        _last_forced_jwks_refresh = now
        _cached_jwks = None
        key = _find_jwk(await _get_jwks(), kid)
        if key is not None:
            return key

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Signing key not found",
        headers={"WWW-Authenticate": "Bearer"},
    )
-
-
async def verify_token(token: str) -> Dict[str, Any]:
    """Validate a Keycloak-issued JWT and return its claims.

    The signing key is chosen by the token's "kid" (via _get_key). The token
    is then decoded with python-jose, checking the signature against
    ``ALGORITHMS``, the audience against ``KEYCLOAK_AUDIENCE`` and the issuer
    against ``KEYCLOAK_ISSUER``.

    Args:
        token: the raw compact JWT (without the "Bearer " prefix).

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            is malformed, badly signed, or fails a claim check; 503 when the
            key set could not be downloaded from Keycloak. The 401 raised by
            _get_key for an unknown key passes through unchanged.
    """
    try:
        key = await _get_key(token)
        claims = jwt.decode(
            token,
            key,
            algorithms=ALGORITHMS,
            audience=KEYCLOAK_AUDIENCE,
            issuer=KEYCLOAK_ISSUER,
            options={"verify_aud": True, "verify_iss": True},
        )
    except JOSEError as exc:
        # JOSEError is the base of JWTError and also of JWKError (e.g. when
        # python-jose cannot build a key for the header's alg). Catching only
        # JWTError let those escape as an HTTP 500.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    except httpx.HTTPError as exc:
        # Keycloak unreachable or erroring: the server is at fault, not the
        # client's token, so report 503 instead of an unhandled 500.
        logger.error("Unable to fetch JWKS from %s: %s", KEYCLOAK_JWKS_URL, exc)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication service unavailable",
        ) from exc
    return claims
-
-
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
) -> Dict[str, Any]:
    """FastAPI dependency that authenticates the request's bearer token.

    ``http_bearer`` extracts the credentials from the Authorization header.
    The raw token string is then validated by verify_token.

    Returns:
        The verified token claims.
    """
    raw_jwt = credentials.credentials
    verified_claims = await verify_token(raw_jwt)
    return verified_claims
-
-
def _extract_roles(entry: Any) -> List[str]:
    """Return ``entry["roles"]`` as a list of str, tolerating missing/null/malformed data."""
    if not isinstance(entry, dict):
        return []
    found = entry.get("roles")
    if not isinstance(found, list):
        return []
    return [role for role in found if isinstance(role, str)]


def require_roles(*roles: str, client_id: Optional[str] = None):
    """Build a FastAPI dependency requiring the caller to hold ALL of ``roles``.

    Roles are read from the verified token claims:
      * realm roles:  ``realm_access.roles``
      * client roles: ``resource_access.<client>.roles``

    Args:
        *roles: role names that must all be present.
        client_id: if given, only client roles granted for this client count
            (realm roles always count). The default None keeps the original
            behavior of merging client roles from every client, so a
            same-named role on an unrelated client also satisfies the check.

    Returns:
        An async dependency that returns the claims when all roles are held.
        Otherwise it raises HTTPException 403 listing the missing roles in
        the order requested.
    """
    async def checker(claims: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
        # realm roles
        have = set(_extract_roles(claims.get("realm_access")))
        # client roles; entries may be absent, null or malformed
        resource_access = claims.get("resource_access")
        if isinstance(resource_access, dict):
            if client_id is None:
                for entry in resource_access.values():
                    have.update(_extract_roles(entry))
            else:
                have.update(_extract_roles(resource_access.get(client_id)))
        missing = [r for r in roles if r not in have]
        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing roles: {missing}",
            )
        return claims

    return checker
|