|
- from typing import Dict, Any, List, Optional
- import os
- import logging
- import httpx
- import config_env
- from jose import jwt, JWTError
- from fastapi import HTTPException, status, Depends
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
- import time
- from datetime import datetime, timezone
- from fastapi import Request
- from starlette.middleware.base import BaseHTTPMiddleware
-
logger = logging.getLogger("security")
logger.setLevel(logging.WARNING)

KEYCLOAK_ISSUER = config_env.KEYCLOAK_ISSUER
KEYCLOAK_JWKS_URL = config_env.KEYCLOAK_JWKS_URL

# Accepted token audiences, read from the comma-separated KEYCLOAK_AUDIENCE
# environment variable (blank entries ignored). An empty list disables the
# audience check in verify_token.
KEYCLOAK_AUDIENCES = [
    a.strip()
    for a in os.getenv("KEYCLOAK_AUDIENCE", "Fastapi").split(",")
    if a.strip()
]

# Asymmetric signature algorithms only: verification keys come from the
# public JWKS, so symmetric (HS*) algorithms must never be accepted here.
ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

# TLS certificate verification for the JWKS endpoint. It is ON by default:
# with verification disabled, anyone able to intercept the connection could
# serve their own signing keys and mint tokens this service would accept.
# Set KEYCLOAK_TLS_VERIFY=false (or 0/no/off) only for local development.
KEYCLOAK_TLS_VERIFY = os.getenv("KEYCLOAK_TLS_VERIFY", "true").strip().lower() not in {
    "0",
    "false",
    "no",
    "off",
}

logger.info("KEYCLOAK_ISSUER=%s", KEYCLOAK_ISSUER)
logger.info("KEYCLOAK_JWKS_URL=%s", KEYCLOAK_JWKS_URL)
logger.info("KEYCLOAK_AUDIENCES=%s", KEYCLOAK_AUDIENCES)
logger.info("ALGORITHMS=%s", ALGORITHMS)
if not KEYCLOAK_TLS_VERIFY:
    logger.warning(
        "TLS certificate verification for %s is DISABLED (KEYCLOAK_TLS_VERIFY)",
        KEYCLOAK_JWKS_URL,
    )

# Shared async HTTP client used to download the JWKS (5 s timeout).
_http = httpx.AsyncClient(timeout=5.0, verify=KEYCLOAK_TLS_VERIFY)

# Parsed JWKS document; None until first fetched by _get_jwks.
_cached_jwks: Optional[Dict[str, Any]] = None

# Bearer-token extractor used by get_current_user (auto_error=True: FastAPI
# rejects requests lacking Bearer credentials before the dependency runs).
http_bearer = HTTPBearer(auto_error=True)
-
-
async def _get_jwks() -> Dict[str, Any]:
    """Return the Keycloak JWKS document, downloading it on first use.

    The parsed JSON is memoised in the module-level ``_cached_jwks``; while it
    is set, no network request is made. Resetting it to ``None`` forces the
    next call to download the document again.

    Raises:
        httpx.HTTPStatusError: if the JWKS endpoint answers with an error
            status (network-level httpx errors propagate as well).
    """
    global _cached_jwks

    if _cached_jwks is not None:
        logger.debug("_get_jwks: cache hit")
        return _cached_jwks

    logger.info(f"_get_jwks: cache miss, fetching from {KEYCLOAK_JWKS_URL}")
    response = await _http.get(KEYCLOAK_JWKS_URL)
    logger.info(f"_get_jwks: response status={response.status_code}")
    response.raise_for_status()

    _cached_jwks = response.json()
    logger.debug(f"_get_jwks: jwks keys={len(_cached_jwks.get('keys', []))}")
    return _cached_jwks
-
-
# Forced JWKS refreshes (triggered by a token whose "kid" is not in the cache)
# are rate-limited. Without this, every request carrying a made-up kid would
# cause an outbound fetch to Keycloak. Interval in seconds, overridable via env.
_JWKS_FORCED_REFRESH_MIN_INTERVAL = float(
    os.getenv("KEYCLOAK_JWKS_REFRESH_MIN_INTERVAL", "30")
)
# time.monotonic() value of the last forced refresh; None means "never".
_last_forced_jwks_refresh: Optional[float] = None


def _find_jwk(jwks: Dict[str, Any], kid: Optional[str]) -> Optional[Dict[str, Any]]:
    """Return the key in *jwks* whose ``kid`` equals *kid*, or None."""
    return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


async def _get_key(token: str) -> Dict[str, Any]:
    """Return the JWK that should verify *token*, selected by its ``kid`` header.

    The cached JWKS is searched first. If the kid is unknown (e.g. Keycloak
    rotated its keys), the JWKS is re-downloaded once. Such forced refreshes
    happen at most every ``_JWKS_FORCED_REFRESH_MIN_INTERVAL`` seconds. If the
    refresh fails, the previously cached keys are kept.

    Raises:
        JWTError: if the token header cannot be decoded.
        HTTPException: 401 if no key with a matching kid is available.
        Exception: whatever ``_get_jwks`` raises when the download fails.
    """
    global _cached_jwks, _last_forced_jwks_refresh

    # The header is read before signature verification only to pick the key.
    kid = jwt.get_unverified_header(token).get("kid")
    logger.debug("_get_key: kid=%s", kid)

    key = _find_jwk(await _get_jwks(), kid)
    if key is not None:
        logger.debug("_get_key: key found in cache")
        return key

    now = time.monotonic()
    if (
        _last_forced_jwks_refresh is None
        or now - _last_forced_jwks_refresh >= _JWKS_FORCED_REFRESH_MIN_INTERVAL
    ):
        _last_forced_jwks_refresh = now
        logger.info("_get_key: key not found, invalidating cache and refetching")
        previous = _cached_jwks
        _cached_jwks = None
        try:
            jwks = await _get_jwks()
        except Exception:
            # Keep serving the keys we already had instead of dropping them.
            _cached_jwks = previous
            raise
        key = _find_jwk(jwks, kid)
        if key is not None:
            logger.debug("_get_key: key found after cache refresh")
            return key
    else:
        logger.info("_get_key: key not found, JWKS refresh skipped (rate limited)")

    logger.error("_get_key: signing key not found for kid")
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Signing key not found")
-
-
def _audiences_from_claims(claims: Dict[str, Any]) -> List[str]:
    """Normalise the ``aud`` claim into a list of strings.

    A string value yields a one-element list. A list, tuple or set yields its
    elements converted with ``str``. Anything else, including a missing claim,
    yields an empty list.
    """
    raw = claims.get("aud")
    if isinstance(raw, (list, tuple, set)):
        return list(map(str, raw))
    return [raw] if isinstance(raw, str) else []
-
-
async def verify_token(token: str) -> Dict[str, Any]:
    """Validate a Keycloak-issued JWT and return its claims.

    The token must:
      * be signed by the JWKS key matching its ``kid`` header (see
        ``_get_key``), using one of ``ALGORITHMS``;
      * carry ``iss`` equal to ``KEYCLOAK_ISSUER``;
      * list at least one of ``KEYCLOAK_AUDIENCES`` in ``aud`` (this check is
        skipped when that list is empty).

    Time-based claims are left to ``jose.jwt.decode``'s default options; only
    ``verify_aud`` is overridden here.

    Raises:
        HTTPException: 401 when the token is malformed, fails validation in
            ``jwt.decode``, targets another audience, or names an unknown
            signing key. 500 on any unexpected failure, for example the JWKS
            endpoint being unreachable.
    """
    logger.info("verify_token: start")
    try:
        key = await _get_key(token)
        logger.debug("verify_token: using key kid=%s", key.get("kid"))
        claims = jwt.decode(
            token,
            key,
            algorithms=ALGORITHMS,
            issuer=KEYCLOAK_ISSUER,
            # The audience is checked below instead, so that a token matching
            # ANY of the configured audiences is accepted.
            options={"verify_aud": False, "verify_iss": True},
        )
        # Log only the subject: the full claim set contains personal data.
        logger.info("verify_token: token decoded, sub=%s", claims.get("sub"))

        token_audiences = _audiences_from_claims(claims)
        logger.info(
            "verify_token: token audiences=%s expected audiences=%s",
            token_audiences,
            KEYCLOAK_AUDIENCES,
        )
        if KEYCLOAK_AUDIENCES and not any(a in token_audiences for a in KEYCLOAK_AUDIENCES):
            logger.error("verify_token: invalid audience, token_audiences=%s", token_audiences)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"Invalid audience: {token_audiences}",
            )
        logger.info("verify_token: audience check passed")
        return claims
    except HTTPException as e:
        logger.error("verify_token: HTTPException status=%s detail=%s", e.status_code, e.detail)
        raise
    except JWTError as e:
        logger.error("verify_token: JWTError=%s", e)
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback needed to diagnose the failure.
        logger.exception("verify_token: unexpected error=%s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Token verification error",
        ) from e
-
-
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
) -> Dict[str, Any]:
    """FastAPI dependency that returns the verified JWT claims of the caller.

    ``http_bearer`` supplies the parsed ``Authorization`` credentials. The raw
    token is then validated by ``verify_token``, whose HTTPException
    (401/500) propagates unchanged.
    """
    logger.info("get_current_user: start")
    logger.debug(f"get_current_user: scheme={credentials.scheme}")

    bearer_token = credentials.credentials
    logger.debug(f"get_current_user: token_prefix={bearer_token[:30]}")

    verified_claims = await verify_token(bearer_token)
    logger.info("get_current_user: token verified")
    return verified_claims
-
-
def _role_list(container: Any) -> List[Any]:
    """Return ``container["roles"]`` as a list, or [] when absent or malformed."""
    if not isinstance(container, dict):
        return []
    found = container.get("roles")
    return list(found) if isinstance(found, (list, tuple)) else []


def require_roles(*roles: str):
    """Build a FastAPI dependency that requires the caller to hold ALL ``roles``.

    The caller's roles are the union of the realm roles (``realm_access.roles``)
    and the roles of every client listed in ``resource_access``. Missing, null
    or malformed entries count as "no roles" rather than causing an error.
    The returned dependency yields the verified claims, e.g.
    ``Depends(require_roles("admin"))``.

    NOTE(review): client roles are not scoped to a particular client, so a
    role with the required name on *any* client satisfies the check. Confirm
    this is intended.

    Raises (from the returned dependency):
        HTTPException: 403 listing the roles the caller lacks.
    """

    async def checker(claims: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
        logger.info("require_roles: required_roles=%s", roles)
        realm_roles = _role_list(claims.get("realm_access"))

        resource_access = claims.get("resource_access")
        if not isinstance(resource_access, dict):
            resource_access = {}
        client_roles: List[Any] = []
        for client_access in resource_access.values():
            client_roles.extend(_role_list(client_access))

        logger.debug("require_roles: realm_roles=%s", realm_roles)
        logger.debug("require_roles: client_roles=%s", client_roles)
        have = set(realm_roles) | set(client_roles)
        logger.info("require_roles: user_roles=%s", have)

        missing = [r for r in roles if r not in have]
        if missing:
            logger.error("require_roles: missing_roles=%s", missing)
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing roles: {missing}",
            )
        logger.info("require_roles: role check passed")
        return claims

    return checker
-
-
# --- Audit logging -----------------------------------------------------------
# Dedicated "audit" logger used by AuditMiddleware. Its level is INFO so that
# per-request lines pass this logger's level check; whether they are actually
# output depends on the handlers configured by the application.
audit_logger: logging.Logger = logging.getLogger("audit")
audit_logger.setLevel("INFO")
-
-
def _extract_user_from_claims(claims: Dict[str, Any]) -> Dict[str, str]:
    """Pick audit-friendly identity fields out of JWT claims.

    Returns a dict with:
      * ``"sub"``: ``str(claims["sub"])``, or ``"-"`` when the claim is absent;
      * ``"username"``: the first truthy value among ``preferred_username``,
        ``username`` and ``email`` (as a string), or ``"-"`` if none is set.
    """
    fallback_order = ("preferred_username", "username", "email")
    username = next(
        (value for value in map(claims.get, fallback_order) if value),
        "-",
    )
    return {"sub": str(claims.get("sub", "-")), "username": str(username)}
-
-
class AuditMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that writes one audit line per HTTP request.

    Each line goes to ``audit_logger`` and records:
      * the completion timestamp (UTC, ISO 8601);
      * the caller's username and ``sub``;
      * the HTTP method and the path with its query string;
      * the response status;
      * the elapsed time in milliseconds.

    Missing identity fields are logged as ``-``. The status is logged as 500
    when the downstream app raises instead of returning a response; the
    exception still propagates.
    """

    # Whether to take the caller identity from the "X-Authenticated-User" /
    # "X-Authenticated-Sub" headers set by an upstream proxy.
    # SECURITY: clients can send these headers themselves, so this is only
    # safe behind a proxy that strips or overwrites them. When False, the
    # identity always comes from the verified Bearer token.
    trust_proxy_identity_headers: bool = True

    async def _resolve_identity(self, request: Request):
        """Return ``(user, sub)`` for the audit line; either may be None or "-".

        Proxy-provided headers are preferred because they are cheaper. Any
        value still missing is filled from the Bearer token via
        ``verify_token``. Verification failures are deliberately swallowed:
        auditing must never block a request, and access control is enforced
        by the route dependencies.
        """
        user: Optional[str] = None
        sub: Optional[str] = None
        if self.trust_proxy_identity_headers:
            user = request.headers.get("x-authenticated-user")
            sub = request.headers.get("x-authenticated-sub")
        if user and sub:
            return user, sub

        auth = request.headers.get("authorization", "")
        if not auth.lower().startswith("bearer "):
            return user, sub

        token = auth.split(" ", 1)[1].strip()
        try:
            claims = await verify_token(token)
            ident = _extract_user_from_claims(claims)
        except Exception:
            return user or "-", sub or "-"
        return user or ident["username"], sub or ident["sub"]

    async def dispatch(self, request: Request, call_next):
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        started = time.perf_counter()

        method = request.method
        path = request.url.path
        query = request.url.query
        user, sub = await self._resolve_identity(request)

        status_code = 500  # reported if call_next raises before returning a response
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration_ms = int((time.perf_counter() - started) * 1000)
            ts = datetime.now(timezone.utc).isoformat()
            full_path = f"{path}?{query}" if query else path
            audit_logger.info(
                "ts=%s user=%s sub=%s method=%s path=%s status=%s duration_ms=%s",
                ts,
                user or "-",
                sub or "-",
                method,
                full_path,
                status_code,
                duration_ms,
            )
|