"""Keycloak-backed JWT authentication helpers for FastAPI.

Provides JWKS fetching/caching, bearer-token verification, a role-checking
dependency factory and an audit-logging middleware.
"""
from typing import Dict, Any, List, Optional
import os
import logging
import httpx
import config_env
from jose import jwt, JWTError
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import time
from datetime import datetime, timezone
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger("security")
# NOTE: with the level at WARNING, the INFO/DEBUG calls in this module are
# dropped by this logger. Log calls use lazy %-style arguments so that the
# (potentially large) values are not formatted when the record is discarded.
logger.setLevel(logging.WARNING)

KEYCLOAK_ISSUER = config_env.KEYCLOAK_ISSUER
KEYCLOAK_JWKS_URL = config_env.KEYCLOAK_JWKS_URL
# Comma-separated list of accepted audiences; blank entries are ignored.
KEYCLOAK_AUDIENCES = [
    a.strip()
    for a in os.getenv("KEYCLOAK_AUDIENCE", "Fastapi").split(",")
    if a.strip()
]
# Only RSA-based signature algorithms are accepted.
ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

logger.info("KEYCLOAK_ISSUER=%s", KEYCLOAK_ISSUER)
logger.info("KEYCLOAK_JWKS_URL=%s", KEYCLOAK_JWKS_URL)
logger.info("KEYCLOAK_AUDIENCES=%s", KEYCLOAK_AUDIENCES)
logger.info("ALGORITHMS=%s", ALGORITHMS)

# SECURITY(review): verify=False disables TLS certificate verification for the
# JWKS download. Anyone able to intercept that connection could serve their own
# signing keys. Kept as-is to avoid breaking existing deployments; prefer
# verify=True or verify="<path to CA bundle>".
_http = httpx.AsyncClient(timeout=5.0, verify=False)
# Process-wide JWKS cache; reset to None to force a refetch.
_cached_jwks: Optional[Dict[str, Any]] = None
http_bearer = HTTPBearer(auto_error=True)


async def _get_jwks() -> Dict[str, Any]:
    """Return the JWKS document, downloading it on the first call.

    The document is cached in the module-level ``_cached_jwks`` until that
    variable is reset to ``None``.

    Raises:
        httpx.HTTPStatusError: If the JWKS endpoint answers with an error
            status (via ``raise_for_status``).
    """
    global _cached_jwks
    if _cached_jwks is None:
        logger.info("_get_jwks: cache miss, fetching from %s", KEYCLOAK_JWKS_URL)
        resp = await _http.get(KEYCLOAK_JWKS_URL)
        logger.info("_get_jwks: response status=%s", resp.status_code)
        resp.raise_for_status()
        _cached_jwks = resp.json()
        logger.debug("_get_jwks: jwks keys=%s", len(_cached_jwks.get("keys", [])))
    else:
        logger.debug("_get_jwks: cache hit")
    return _cached_jwks


def _find_key_by_kid(jwks: Dict[str, Any], kid: Optional[str]) -> Optional[Dict[str, Any]]:
    """Return the first JWKS entry whose ``kid`` equals *kid*, or ``None``."""
    for key in jwks.get("keys", []):
        if key.get("kid") == kid:
            return key
    return None


async def _get_key(token: str) -> Dict[str, Any]:
    """Return the JWKS entry matching the ``kid`` in the token's header.

    The cached JWKS is searched first; on a miss the cache is invalidated and
    the JWKS is fetched once more before giving up.

    Raises:
        HTTPException: 401 when no key with the token's ``kid`` exists even
            after the refetch.
    """
    global _cached_jwks
    logger.debug("_get_key: start, token_prefix=%s", token[:20])
    headers = jwt.get_unverified_header(token)
    kid = headers.get("kid")
    logger.debug("_get_key: kid=%s", kid)

    key = _find_key_by_kid(await _get_jwks(), kid)
    if key is not None:
        logger.debug("_get_key: key found in cache")
        return key

    # Unknown kid: drop the cache and refetch once before rejecting.
    logger.info("_get_key: key not found, invalidating cache and refetching")
    _cached_jwks = None
    key = _find_key_by_kid(await _get_jwks(), kid)
    if key is not None:
        logger.debug("_get_key: key found after cache refresh")
        return key

    logger.error("_get_key: signing key not found for kid")
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Signing key not found",
    )


def _audiences_from_claims(claims: Dict[str, Any]) -> List[str]:
    """Normalize the ``aud`` claim to a list of strings.

    A string becomes a one-element list, a list/tuple/set has each item
    converted with ``str``, and anything else (including a missing claim)
    yields an empty list.
    """
    aud = claims.get("aud")
    if isinstance(aud, str):
        return [aud]
    if isinstance(aud, (list, tuple, set)):
        return [str(x) for x in aud]
    return []


async def verify_token(token: str) -> Dict[str, Any]:
    """Verify a bearer token and return its claims.

    The signature and issuer are checked by ``jwt.decode``; the audience is
    checked here instead, and passes when the token carries at least one of
    ``KEYCLOAK_AUDIENCES``.

    Raises:
        HTTPException: 401 for an unknown signing key, a ``JWTError`` from
            decoding, or an audience mismatch; 500 for any other failure.
    """
    logger.info("verify_token: start")
    logger.debug("verify_token: token_prefix=%s", token[:30])
    try:
        key = await _get_key(token)
        logger.debug("verify_token: using key kid=%s", key.get("kid"))
        claims = jwt.decode(
            token,
            key,
            algorithms=ALGORITHMS,
            issuer=KEYCLOAK_ISSUER,
            # Library audience check is off; it is done below against a list.
            options={"verify_aud": False, "verify_iss": True},
        )
        logger.info("verify_token: token decoded")
        logger.debug("verify_token: claims=%s", claims)

        token_audiences = _audiences_from_claims(claims)
        logger.info("verify_token: token audiences=%s", token_audiences)
        logger.info("verify_token: expected audiences=%s", KEYCLOAK_AUDIENCES)
        if KEYCLOAK_AUDIENCES:
            if not any(a in token_audiences for a in KEYCLOAK_AUDIENCES):
                logger.error(
                    "verify_token: invalid audience, token_audiences=%s",
                    token_audiences,
                )
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=f"Invalid audience: {token_audiences}",
                )
        logger.info("verify_token: audience check passed")
        return claims
    except HTTPException as e:
        logger.error(
            "verify_token: HTTPException status=%s detail=%s",
            e.status_code,
            e.detail,
        )
        raise
    except JWTError as e:
        logger.error("verify_token: JWTError=%s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)
        ) from e
    except Exception as e:
        # Keep the traceback: this branch hides the real cause behind a 500.
        logger.exception("verify_token: unexpected error=%s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Token verification error",
        ) from e


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
) -> Dict[str, Any]:
    """FastAPI dependency returning the verified claims of the bearer token.

    Raises:
        HTTPException: Propagated from ``verify_token``.
    """
    logger.info("get_current_user: start")
    logger.debug("get_current_user: scheme=%s", credentials.scheme)
    token = credentials.credentials
    logger.debug("get_current_user: token_prefix=%s", token[:30])
    claims = await verify_token(token)
    logger.info("get_current_user: token verified")
    return claims


def _roles_from_section(section: Any) -> List[str]:
    """Return the string entries of ``section["roles"]``, or ``[]`` if malformed."""
    if not isinstance(section, dict):
        return []
    roles = section.get("roles")
    if not isinstance(roles, (list, tuple, set)):
        return []
    return [r for r in roles if isinstance(r, str)]


def require_roles(*roles: str):
    """Build a dependency that requires every role in *roles*.

    Roles are collected from ``realm_access.roles`` and from the ``roles`` of
    every entry in ``resource_access``. Missing or malformed sections count
    as "no roles" instead of raising.

    Returns:
        An async dependency that returns the claims when all roles are
        present and raises ``HTTPException`` 403 otherwise.
    """

    async def checker(
        claims: Dict[str, Any] = Depends(get_current_user),
    ) -> Dict[str, Any]:
        logger.info("require_roles: required_roles=%s", roles)
        realm_roles: List[str] = _roles_from_section(claims.get("realm_access"))
        client_roles: List[str] = []
        resource_access = claims.get("resource_access")
        if isinstance(resource_access, dict):
            for entry in resource_access.values():
                client_roles.extend(_roles_from_section(entry))
        logger.debug("require_roles: realm_roles=%s", realm_roles)
        logger.debug("require_roles: client_roles=%s", client_roles)

        have = set(realm_roles + client_roles)
        logger.info("require_roles: user_roles=%s", have)
        missing = [r for r in roles if r not in have]
        if missing:
            logger.error("require_roles: missing_roles=%s", missing)
            raise HTTPException(status_code=403, detail=f"Missing roles: {missing}")
        logger.info("require_roles: role check passed")
        return claims

    return checker


# Audit logging
audit_logger = logging.getLogger("audit")
audit_logger.setLevel(logging.INFO)


def _extract_user_from_claims(claims: Dict[str, Any]) -> Dict[str, str]:
    """Return ``{"sub", "username"}`` from claims, using ``"-"`` when absent.

    The username is the first non-empty value among ``preferred_username``,
    ``username`` and ``email``.
    """
    return {
        "sub": str(claims.get("sub", "-")),
        "username": str(
            claims.get("preferred_username")
            or claims.get("username")
            or claims.get("email")
            or "-"
        ),
    }


class AuditMiddleware(BaseHTTPMiddleware):
    """Write one audit log line per request: who, what, status and duration."""

    async def dispatch(self, request: Request, call_next):
        """Resolve the caller identity, run the request, then log the outcome.

        The audit line is emitted in a ``finally`` block, so it is written
        even when the downstream handler raises; in that case the logged
        status stays at its 500 default.
        """
        # Monotonic clock: durations are unaffected by wall-clock adjustments.
        start = time.perf_counter()
        method = request.method
        path = request.url.path
        query = request.url.query

        # 1) Try the identity forwarded by the proxy first (fastest path).
        # NOTE(review): these headers are client-controlled unless an upstream
        # proxy strips/overwrites them — confirm the deployment guarantees it.
        user = request.headers.get("x-authenticated-user")
        sub = request.headers.get("x-authenticated-sub")

        # 2) Fallback: if the headers are missing, try decoding the bearer token.
        if (not user or not sub) and "authorization" in request.headers:
            auth = request.headers.get("authorization", "")
            if auth.lower().startswith("bearer "):
                token = auth.split(" ", 1)[1].strip()
                try:
                    claims = await verify_token(token)
                    ident = _extract_user_from_claims(claims)
                    sub = sub or ident["sub"]
                    user = user or ident["username"]
                except Exception:
                    # Never block the request for auditing; the actual
                    # security enforcement is done by the route dependencies.
                    user = user or "-"
                    sub = sub or "-"

        status_code = 500
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration_ms = int((time.perf_counter() - start) * 1000)
            ts = datetime.now(timezone.utc).isoformat()
            full_path = path + (("?" + query) if query else "")
            audit_logger.info(
                "ts=%s user=%s sub=%s method=%s path=%s status=%s duration_ms=%s",
                ts,
                user or "-",
                sub or "-",
                method,
                full_path,
                status_code,
                duration_ms,
            )