Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.
 
 
 
 

220 рядків
8.0 KiB

  1. from typing import Dict, Any, List, Optional
  2. import os
  3. import logging
  4. import httpx
  5. import config_env
  6. from jose import jwt, JWTError
  7. from fastapi import HTTPException, status, Depends
  8. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  9. import time
  10. from datetime import datetime, timezone
  11. from fastapi import Request
  12. from starlette.middleware.base import BaseHTTPMiddleware
  logger = logging.getLogger("security")
  logger.setLevel(logging.WARNING)

  KEYCLOAK_ISSUER = config_env.KEYCLOAK_ISSUER
  KEYCLOAK_JWKS_URL = config_env.KEYCLOAK_JWKS_URL
  # Accepted "aud" values, read from the comma-separated KEYCLOAK_AUDIENCE env
  # var (blank entries dropped). An empty list disables the audience check.
  KEYCLOAK_AUDIENCES = [
      a.strip()
      for a in os.getenv("KEYCLOAK_AUDIENCE", "Fastapi").split(",")
      if a.strip()
  ]
  # Asymmetric algorithms only: tokens are verified with public keys taken from
  # the JWKS, so HMAC ("HS*") or "none" must never be accepted here.
  ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
  logger.info(f"KEYCLOAK_ISSUER={KEYCLOAK_ISSUER}")
  logger.info(f"KEYCLOAK_JWKS_URL={KEYCLOAK_JWKS_URL}")
  logger.info(f"KEYCLOAK_AUDIENCES={KEYCLOAK_AUDIENCES}")
  logger.info(f"ALGORITHMS={ALGORITHMS}")


  def _jwks_tls_verify() -> Any:
      """Resolve the TLS ``verify`` setting for the JWKS HTTP client.

      Driven by the ``KEYCLOAK_TLS_VERIFY`` environment variable:
        * unset, "", "1", "true", "yes", "on" -> verify with httpx's default CA bundle
        * "0", "false", "no", "off"           -> disable verification (dev only)
        * anything else                       -> treated as a path to a CA bundle file

      Returns:
          ``True``/``False`` or an ``ssl.SSLContext`` loaded with the given CA file.
      """
      raw = os.getenv("KEYCLOAK_TLS_VERIFY", "true").strip()
      flag = raw.lower()
      if flag in ("", "1", "true", "yes", "on"):
          return True
      if flag in ("0", "false", "no", "off"):
          logger.warning(
              "KEYCLOAK_TLS_VERIFY disables TLS verification for %s", KEYCLOAK_JWKS_URL
          )
          return False
      import ssl  # only needed when a custom CA bundle is configured

      return ssl.create_default_context(cafile=raw)


  # The JWKS is the trust anchor for every token signature check: downloading it
  # without certificate verification would let a network attacker substitute
  # their own keys and mint arbitrary tokens. Verification is therefore on unless
  # explicitly disabled via KEYCLOAK_TLS_VERIFY.
  _http = httpx.AsyncClient(timeout=5.0, verify=_jwks_tls_verify())
  # Parsed JWKS document; None until first fetched.
  _cached_jwks: Optional[Dict[str, Any]] = None
  http_bearer = HTTPBearer(auto_error=True)
  async def _get_jwks() -> Dict[str, Any]:
      """Return the Keycloak JWKS document, downloading it only when not cached.

      The parsed JSON is memoised in the module-level ``_cached_jwks``; a
      download happens only while that global is ``None``.

      Raises:
          httpx errors from the request or from ``raise_for_status()``.
      """
      global _cached_jwks
      if _cached_jwks is not None:
          logger.debug("_get_jwks: cache hit")
          return _cached_jwks

      logger.info(f"_get_jwks: cache miss, fetching from {KEYCLOAK_JWKS_URL}")
      response = await _http.get(KEYCLOAK_JWKS_URL)
      logger.info(f"_get_jwks: response status={response.status_code}")
      response.raise_for_status()
      _cached_jwks = response.json()
      key_count = len(_cached_jwks.get('keys', []))
      logger.debug(f"_get_jwks: jwks keys={key_count}")
      return _cached_jwks
  # time.monotonic() of the last JWKS refresh forced by an unknown "kid";
  # -inf so the very first forced refresh is never throttled.
  _jwks_last_forced_refresh: float = float("-inf")


  async def _get_key(token: str, *, min_refresh_interval: float = 10.0) -> Dict[str, Any]:
      """Return the JWK whose ``kid`` matches the token's (unverified) header.

      The cached JWKS (via ``_get_jwks``) is searched first. On a miss, for
      example after Keycloak rotated its signing keys, a fresh JWKS is
      downloaded and searched. Forced refreshes are throttled to at most one per
      ``min_refresh_interval`` seconds, so tokens carrying made-up ``kid``
      values cannot turn every request into a JWKS download.

      Args:
          token: the raw JWT. Only its header is read here; the signature is
              checked by the caller with the returned key.
          min_refresh_interval: minimum number of seconds between forced
              JWKS refreshes.

      Returns:
          The matching JWK (an entry of the JWKS ``keys`` array).

      Raises:
          JWTError: the token header cannot be parsed.
          HTTPException: 401 when no key with the token's ``kid`` is available.
          httpx errors from downloading the JWKS propagate to the caller.
      """
      global _cached_jwks, _jwks_last_forced_refresh

      kid = jwt.get_unverified_header(token).get("kid")
      logger.debug("_get_key: kid=%s", kid)

      key = _find_key(await _get_jwks(), kid)
      if key is not None:
          logger.debug("_get_key: key found in cache")
          return key

      now = time.monotonic()
      if now - _jwks_last_forced_refresh < min_refresh_interval:
          logger.debug("_get_key: key not found, JWKS refresh throttled")
      else:
          # Claim the refresh slot before awaiting: concurrent requests with
          # unknown kids then see the new timestamp and skip their own download.
          _jwks_last_forced_refresh = now
          logger.info("_get_key: key not found, refetching JWKS")
          resp = await _http.get(KEYCLOAK_JWKS_URL)
          resp.raise_for_status()
          fresh = resp.json()
          # Replace the cache only after a successful download, so a failed
          # refresh (or a concurrent reader) never observes an empty cache.
          _cached_jwks = fresh
          key = _find_key(fresh, kid)
          if key is not None:
              logger.debug("_get_key: key found after cache refresh")
              return key

      logger.error("_get_key: signing key not found for kid")
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail="Signing key not found",
          headers={"WWW-Authenticate": "Bearer"},
      )


  def _find_key(jwks: Dict[str, Any], kid: Any) -> Optional[Dict[str, Any]]:
      """Return the first entry of ``jwks["keys"]`` whose ``kid`` equals ``kid``, or None."""
      return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)
  62. def _audiences_from_claims(claims: Dict[str, Any]) -> List[str]:
  63. aud = claims.get("aud")
  64. if isinstance(aud, str):
  65. return [aud]
  66. if isinstance(aud, (list, tuple, set)):
  67. return [str(x) for x in aud]
  68. return []
  async def verify_token(token: str) -> Dict[str, Any]:
      """Verify a Keycloak-issued JWT and return its claims.

      Steps:
        1. Look up the signing key by the token's ``kid`` (``_get_key``).
        2. ``jwt.decode`` checks the signature (one of ``ALGORITHMS``) and the
           issuer (``KEYCLOAK_ISSUER``). Only ``verify_aud`` is switched off, so
           python-jose's other default claim checks stay enabled.
        3. If ``KEYCLOAK_AUDIENCES`` is non-empty, at least one of them must be
           in the token's ``aud``.

      Raises:
          HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge) for
              malformed/invalid tokens, unknown signing keys or audience
              mismatch; 500 for unexpected failures such as the JWKS download
              failing.
      """
      # RFC 6750: a 401 for a bearer-protected resource carries this challenge.
      bearer_challenge = {"WWW-Authenticate": "Bearer"}
      logger.info("verify_token: start")
      try:
          key = await _get_key(token)
          logger.debug("verify_token: using key kid=%s", key.get("kid"))
          claims = jwt.decode(
              token,
              key,
              algorithms=ALGORITHMS,
              issuer=KEYCLOAK_ISSUER,
              # audience is checked below so that several values can be accepted
              options={"verify_aud": False, "verify_iss": True},
          )
          logger.info("verify_token: token decoded")
          logger.debug("verify_token: claims=%s", claims)
          token_audiences = _audiences_from_claims(claims)
          logger.info("verify_token: token audiences=%s", token_audiences)
          logger.info("verify_token: expected audiences=%s", KEYCLOAK_AUDIENCES)
          if KEYCLOAK_AUDIENCES and not set(KEYCLOAK_AUDIENCES).intersection(token_audiences):
              logger.error("verify_token: invalid audience, token_audiences=%s", token_audiences)
              raise HTTPException(
                  status_code=status.HTTP_401_UNAUTHORIZED,
                  detail=f"Invalid audience: {token_audiences}",
                  headers=bearer_challenge,
              )
          logger.info("verify_token: audience check passed")
          return claims
      except HTTPException as e:
          logger.error("verify_token: HTTPException status=%s detail=%s", e.status_code, e.detail)
          raise
      except JWTError as e:
          logger.error("verify_token: JWTError=%s", e)
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=str(e),
              headers=bearer_challenge,
          ) from e
      except Exception as e:
          # Not a client error: keep the traceback so the root cause is visible.
          logger.exception("verify_token: unexpected error=%s", e)
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail="Token verification error",
          ) from e
  async def get_current_user(
      credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
  ) -> Dict[str, Any]:
      """FastAPI dependency returning the verified claims of the request's Bearer token.

      ``http_bearer`` supplies the parsed ``Authorization`` header; the raw
      token is passed to ``verify_token``, whose HTTPExceptions propagate.
      """
      logger.info("get_current_user: start")
      scheme, raw_token = credentials.scheme, credentials.credentials
      logger.debug(f"get_current_user: scheme={scheme}")
      logger.debug(f"get_current_user: token_prefix={raw_token[:30]}")
      verified_claims = await verify_token(raw_token)
      logger.info("get_current_user: token verified")
      return verified_claims
  def require_roles(*roles: str):
      """Build a FastAPI dependency that requires the caller to hold *all* ``roles``.

      Roles are gathered from ``realm_access.roles`` and from
      ``resource_access.<client>.roles`` for every client present in the token.

      NOTE(review): client roles from all clients are pooled, so a role granted
      on *any* client satisfies the check. Confirm this is intended.

      Args:
          *roles: role names that must all be present.

      Returns:
          An async dependency that returns the verified claims, or raises
          HTTPException 403 listing the missing roles.
      """

      async def checker(claims: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
          logger.info("require_roles: required_roles=%s", roles)
          # `or` guards treat a missing or null claim the same as an empty one.
          realm_roles: List[str] = (claims.get("realm_access") or {}).get("roles") or []
          client_roles: List[str] = []
          for client_access in (claims.get("resource_access") or {}).values():
              # A malformed entry ("roles": null, non-dict) would otherwise raise
              # a TypeError/AttributeError and surface as a 500.
              if isinstance(client_access, dict):
                  client_roles.extend(client_access.get("roles") or [])
          logger.debug("require_roles: realm_roles=%s", realm_roles)
          logger.debug("require_roles: client_roles=%s", client_roles)
          have = set(realm_roles + client_roles)
          logger.info("require_roles: user_roles=%s", have)
          missing = [r for r in roles if r not in have]
          if missing:
              logger.error("require_roles: missing_roles=%s", missing)
              raise HTTPException(
                  status_code=status.HTTP_403_FORBIDDEN,
                  detail=f"Missing roles: {missing}",
              )
          logger.info("require_roles: role check passed")
          return claims

      return checker
  # ----------------------------------------------------------------------------
  # Audit logging: AuditMiddleware writes one INFO record per request to the
  # dedicated "audit" logger.
  # ----------------------------------------------------------------------------
  audit_logger = logging.getLogger("audit")
  audit_logger.setLevel(level=logging.INFO)
  136. def _extract_user_from_claims(claims: Dict[str, Any]) -> Dict[str, str]:
  137. return {
  138. "sub": str(claims.get("sub", "-")),
  139. "username": str(
  140. claims.get("preferred_username")
  141. or claims.get("username")
  142. or claims.get("email")
  143. or "-"
  144. ),
  145. }
  class AuditMiddleware(BaseHTTPMiddleware):
      """Write one ``audit`` log record per HTTP request.

      Each record holds the UTC timestamp taken when the request finishes, the
      caller identity (``user`` / ``sub``, ``-`` when unknown), method, path
      plus query string, response status and duration in milliseconds. If the
      downstream app raises, the record is still written with status 500 and
      the exception propagates. This middleware never rejects a request;
      authentication is enforced by the route dependencies.
      """

      async def dispatch(self, request: Request, call_next):
          # perf_counter is monotonic: unlike time.time() it does not jump when
          # the system clock is adjusted, so durations are never skewed/negative.
          started = time.perf_counter()
          method = request.method
          path = request.url.path
          query = request.url.query
          user, sub = await self._resolve_identity(request)

          status_code = 500  # reported if call_next raises
          try:
              response = await call_next(request)
              status_code = response.status_code
              return response
          finally:
              duration_ms = int((time.perf_counter() - started) * 1000)
              ts = datetime.now(timezone.utc).isoformat()
              full_path = f"{path}?{query}" if query else path
              audit_logger.info(
                  "ts=%s user=%s sub=%s method=%s path=%s status=%s duration_ms=%s",
                  ts,
                  user or "-",
                  sub or "-",
                  method,
                  full_path,
                  status_code,
                  duration_ms,
              )

      @staticmethod
      async def _resolve_identity(request: Request):
          """Return ``(user, sub)`` for the audit record; either may be None/empty.

          1) Prefer the ``x-authenticated-user`` / ``x-authenticated-sub``
             headers (expected to be set by a reverse proxy).
          2) If either is missing and a Bearer token is present, verify it and
             fill in the gaps from its claims. Verification failures are
             swallowed (missing values become ``"-"``) so auditing never blocks
             a request.

          NOTE(review): the x-authenticated-* headers are trusted as-is. Unless
          the proxy in front of this app strips/overwrites them, any client can
          set them and spoof the audited identity. Confirm the deployment.
          """
          user = request.headers.get("x-authenticated-user")
          sub = request.headers.get("x-authenticated-sub")
          if user and sub:
              return user, sub

          auth = request.headers.get("authorization", "")
          if not auth.lower().startswith("bearer "):
              return user, sub

          token = auth.split(" ", 1)[1].strip()
          try:
              claims = await verify_token(token)
              ident = _extract_user_from_claims(claims)
          except Exception:
              # Best effort only: real access control happens in the dependencies.
              return user or "-", sub or "-"
          return user or ident["username"], sub or ident["sub"]