Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

228 řádků
8.4 KiB

  1. from typing import Dict, Any, List, Optional
  2. import os
  3. import logging
  4. import httpx
  5. import config_env
  6. from jose import jwt, JWTError
  7. from fastapi import HTTPException, status, Depends
  8. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  9. import time
  10. from datetime import datetime, timezone
  11. from fastapi import Request
  12. from starlette.middleware.base import BaseHTTPMiddleware
  13. logger = logging.getLogger("security")
  14. logger.setLevel(logging.WARNING)
  15. KEYCLOAK_ISSUER = config_env.KEYCLOAK_ISSUER
  16. KEYCLOAK_JWKS_URL = config_env.KEYCLOAK_JWKS_URL
  17. KEYCLOAK_AUDIENCES = [
  18. a.strip()
  19. for a in os.getenv("KEYCLOAK_AUDIENCE", "Fastapi").split(",")
  20. if a.strip()
  21. ]
  22. ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]
  23. logger.info(f"KEYCLOAK_ISSUER={KEYCLOAK_ISSUER}")
  24. logger.info(f"KEYCLOAK_JWKS_URL={KEYCLOAK_JWKS_URL}")
  25. logger.info(f"KEYCLOAK_AUDIENCES={KEYCLOAK_AUDIENCES}")
  26. logger.info(f"ALGORITHMS={ALGORITHMS}")
  27. _http = httpx.AsyncClient(timeout=5.0, verify=False)
  28. _cached_jwks: Optional[Dict[str, Any]] = None
  29. http_bearer = HTTPBearer(auto_error=True)
  30. async def _get_jwks() -> Dict[str, Any]:
  31. global _cached_jwks
  32. if not KEYCLOAK_JWKS_URL or not (
  33. KEYCLOAK_JWKS_URL.startswith("http://") or KEYCLOAK_JWKS_URL.startswith("https://")
  34. ):
  35. logger.error("_get_jwks: invalid KEYCLOAK_JWKS_URL=%r", KEYCLOAK_JWKS_URL)
  36. raise HTTPException(
  37. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  38. detail="Invalid Keycloak JWKS URL configuration",
  39. )
  40. if _cached_jwks is None:
  41. logger.info(f"_get_jwks: cache miss, fetching from {KEYCLOAK_JWKS_URL}")
  42. resp = await _http.get(KEYCLOAK_JWKS_URL)
  43. logger.info(f"_get_jwks: response status={resp.status_code}")
  44. resp.raise_for_status()
  45. _cached_jwks = resp.json()
  46. logger.debug(f"_get_jwks: jwks keys={len(_cached_jwks.get('keys', []))}")
  47. else:
  48. logger.debug("_get_jwks: cache hit")
  49. return _cached_jwks
  50. async def _get_key(token: str) -> Dict[str, Any]:
  51. logger.debug(f"_get_key: start, token_prefix={token[:20]}")
  52. headers = jwt.get_unverified_header(token)
  53. kid = headers.get("kid")
  54. logger.debug(f"_get_key: kid={kid}")
  55. jwks = await _get_jwks()
  56. for key in jwks.get("keys", []):
  57. if key.get("kid") == kid:
  58. logger.debug("_get_key: key found in cache")
  59. return key
  60. global _cached_jwks
  61. logger.info("_get_key: key not found, invalidating cache and refetching")
  62. _cached_jwks = None
  63. jwks = await _get_jwks()
  64. for key in jwks.get("keys", []):
  65. if key.get("kid") == kid:
  66. logger.debug("_get_key: key found after cache refresh")
  67. return key
  68. logger.error("_get_key: signing key not found for kid")
  69. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Signing key not found")
  70. def _audiences_from_claims(claims: Dict[str, Any]) -> List[str]:
  71. aud = claims.get("aud")
  72. if isinstance(aud, str):
  73. return [aud]
  74. if isinstance(aud, (list, tuple, set)):
  75. return [str(x) for x in aud]
  76. return []
  77. async def verify_token(token: str) -> Dict[str, Any]:
  78. logger.info("verify_token: start")
  79. logger.debug(f"verify_token: token_prefix={token[:30]}")
  80. try:
  81. key = await _get_key(token)
  82. logger.debug(f"verify_token: using key kid={key.get('kid')}")
  83. claims = jwt.decode(
  84. token,
  85. key,
  86. algorithms=ALGORITHMS,
  87. issuer=KEYCLOAK_ISSUER,
  88. options={"verify_aud": False, "verify_iss": True},
  89. )
  90. logger.info("verify_token: token decoded")
  91. logger.debug(f"verify_token: claims={claims}")
  92. token_audiences = _audiences_from_claims(claims)
  93. logger.info(f"verify_token: token audiences={token_audiences}")
  94. logger.info(f"verify_token: expected audiences={KEYCLOAK_AUDIENCES}")
  95. if KEYCLOAK_AUDIENCES:
  96. if not any(a in token_audiences for a in KEYCLOAK_AUDIENCES):
  97. logger.error(f"verify_token: invalid audience, token_audiences={token_audiences}")
  98. raise HTTPException(
  99. status_code=status.HTTP_401_UNAUTHORIZED,
  100. detail=f"Invalid audience: {token_audiences}",
  101. )
  102. logger.info("verify_token: audience check passed")
  103. return claims
  104. except HTTPException as e:
  105. logger.error(f"verify_token: HTTPException status={e.status_code} detail={e.detail}")
  106. raise
  107. except JWTError as e:
  108. logger.error(f"verify_token: JWTError={e}")
  109. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
  110. except Exception as e:
  111. logger.error(f"verify_token: unexpected error={e}")
  112. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Token verification error")
  113. async def get_current_user(
  114. credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
  115. ) -> Dict[str, Any]:
  116. logger.info("get_current_user: start")
  117. logger.debug(f"get_current_user: scheme={credentials.scheme}")
  118. token = credentials.credentials
  119. logger.debug(f"get_current_user: token_prefix={token[:30]}")
  120. claims = await verify_token(token)
  121. logger.info("get_current_user: token verified")
  122. return claims
  123. def require_roles(*roles: str):
  124. async def checker(claims: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
  125. logger.info(f"require_roles: required_roles={roles}")
  126. realm_roles: List[str] = (claims.get("realm_access") or {}).get("roles", []) or []
  127. client_roles: List[str] = []
  128. for client_id, v in (claims.get("resource_access") or {}).items():
  129. client_roles += v.get("roles", [])
  130. logger.debug(f"require_roles: realm_roles={realm_roles}")
  131. logger.debug(f"require_roles: client_roles={client_roles}")
  132. have = set(realm_roles + client_roles)
  133. logger.info(f"require_roles: user_roles={have}")
  134. missing = [r for r in roles if r not in have]
  135. if missing:
  136. logger.error(f"require_roles: missing_roles={missing}")
  137. raise HTTPException(status_code=403, detail=f"Missing roles: {missing}")
  138. logger.info("require_roles: role check passed")
  139. return claims
  140. return checker
  141. #Logging
  142. audit_logger = logging.getLogger("audit")
  143. audit_logger.setLevel(logging.INFO)
  144. def _extract_user_from_claims(claims: Dict[str, Any]) -> Dict[str, str]:
  145. return {
  146. "sub": str(claims.get("sub", "-")),
  147. "username": str(
  148. claims.get("preferred_username")
  149. or claims.get("username")
  150. or claims.get("email")
  151. or "-"
  152. ),
  153. }
  154. class AuditMiddleware(BaseHTTPMiddleware):
  155. async def dispatch(self, request: Request, call_next):
  156. start = time.time()
  157. method = request.method
  158. path = request.url.path
  159. query = request.url.query
  160. # 1) prova a leggere identità "fidata" passata dal proxy (più veloce)
  161. user = request.headers.get("x-authenticated-user")
  162. sub = request.headers.get("x-authenticated-sub")
  163. # 2) fallback: se non arrivano header, prova a decodificare il Bearer token
  164. if (not user or not sub) and "authorization" in request.headers:
  165. auth = request.headers.get("authorization", "")
  166. if auth.lower().startswith("bearer "):
  167. token = auth.split(" ", 1)[1].strip()
  168. try:
  169. claims = await verify_token(token)
  170. ident = _extract_user_from_claims(claims)
  171. sub = sub or ident["sub"]
  172. user = user or ident["username"]
  173. except Exception:
  174. # non bloccare la request per audit; la security vera la gestiscono le Depends
  175. user = user or "-"
  176. sub = sub or "-"
  177. status_code = 500
  178. try:
  179. response = await call_next(request)
  180. status_code = response.status_code
  181. return response
  182. finally:
  183. duration_ms = int((time.time() - start) * 1000)
  184. ts = datetime.now(timezone.utc).isoformat()
  185. full_path = path + (("?" + query) if query else "")
  186. audit_logger.info(
  187. "ts=%s user=%s sub=%s method=%s path=%s status=%s duration_ms=%s",
  188. ts,
  189. user or "-",
  190. sub or "-",
  191. method,
  192. full_path,
  193. status_code,
  194. duration_ms,
  195. )