選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。
 
 
 
 

106 行
3.4 KiB

  1. # security.py
  2. from typing import Dict, Any, List, Optional
  3. import os
  4. import logging
  5. import httpx
  6. import config_env
  7. from jose import jwt, JWTError
  8. from fastapi import HTTPException, status, Depends
  9. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  logger = logging.getLogger("security")

  # === CONFIG ===
  # Issuer and JWKS endpoint come from the project's config_env module.
  KEYCLOAK_ISSUER = config_env.KEYCLOAK_ISSUER
  KEYCLOAK_JWKS_URL = config_env.KEYCLOAK_JWKS_URL
  # Expected "aud" claim of incoming tokens; override with KEYCLOAK_AUDIENCE.
  KEYCLOAK_AUDIENCE = os.getenv("KEYCLOAK_AUDIENCE", "Fastapi")
  # Only asymmetric RSA / RSA-PSS algorithms. Keep HS* (HMAC) out of this
  # list: accepting them would allow algorithm-confusion attacks where the
  # public key is used as an HMAC secret.
  ALGORITHMS = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

  # TLS verification for the JWKS download, controlled by KEYCLOAK_TLS_VERIFY:
  #   unset / "" / "0" / "false" / "no" -> verification disabled (previous
  #                                        behaviour; self-signed test setups)
  #   "1" / "true" / "yes"              -> verify against the system CA store
  #   anything else                     -> path to a CA bundle, e.g. /path/CA.crt
  # With verification disabled, anyone able to intercept this connection can
  # serve their own JWKS and mint tokens that verify_token() would accept, so
  # it must be enabled in production.
  _tls_setting = os.getenv("KEYCLOAK_TLS_VERIFY", "").strip()
  _tls_verify: Any
  if _tls_setting.lower() in ("", "0", "false", "no"):
      _tls_verify = False
      logger.warning(
          "TLS certificate verification is DISABLED for %s; "
          "set KEYCLOAK_TLS_VERIFY to 'true' or to a CA bundle path",
          KEYCLOAK_JWKS_URL,
      )
  elif _tls_setting.lower() in ("1", "true", "yes"):
      _tls_verify = True
  else:
      _tls_verify = _tls_setting
  _http = httpx.AsyncClient(timeout=5.0, verify=_tls_verify)

  # Parsed JWKS document, filled lazily by _get_jwks(); None means "not fetched".
  _cached_jwks: Optional[Dict[str, Any]] = None

  # Do NOT name this 'security', to avoid clashing with the module's own name.
  http_bearer = HTTPBearer(auto_error=True)
  async def _get_jwks() -> Dict[str, Any]:
      """Return the realm's JWKS document, downloading it on first use.

      The parsed JSON response from ``KEYCLOAK_JWKS_URL`` is memoised in the
      module-level ``_cached_jwks``; a caller that needs fresh keys resets that
      variable to ``None`` before calling again.

      Raises:
          httpx.HTTPStatusError: Keycloak answered with a non-2xx status
              (propagated from ``raise_for_status``).
      """
      global _cached_jwks
      # Fast path: already downloaded.
      if _cached_jwks is not None:
          return _cached_jwks

      logger.info(f"Fetching JWKS from: {KEYCLOAK_JWKS_URL}")
      response = await _http.get(KEYCLOAK_JWKS_URL)
      response.raise_for_status()
      _cached_jwks = response.json()
      return _cached_jwks
  # Minimum delay (seconds) between JWKS re-downloads triggered by an unknown
  # "kid". Without it, every request carrying a token with a made-up kid would
  # cause a new HTTP request to Keycloak.
  _JWKS_FORCED_REFRESH_INTERVAL = 10.0
  # time.monotonic() of the last forced refresh; None means "never".
  _last_forced_refresh: Optional[float] = None


  def _find_key(jwks: Dict[str, Any], kid: Optional[str]) -> Optional[Dict[str, Any]]:
      """Return the first JWK in *jwks* whose ``kid`` equals *kid*, or ``None``."""
      return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


  async def _get_key(token: str) -> Dict[str, Any]:
      """Return the JWK that should verify *token*, selected by its ``kid`` header.

      The cached JWKS is searched first. On a miss (for example after a key
      rotation in Keycloak) the cache is dropped and the JWKS is downloaded
      again, but at most once every ``_JWKS_FORCED_REFRESH_INTERVAL`` seconds.

      Raises:
          jose.JWTError: the token header cannot be parsed.
          HTTPException(401): no key with a matching ``kid`` is available.
      """
      import time  # local import keeps this block self-contained

      global _cached_jwks, _last_forced_refresh
      kid = jwt.get_unverified_header(token).get("kid")
      key = _find_key(await _get_jwks(), kid)
      if key is not None:
          return key

      # Key possibly rotated: invalidate the cache and retry, rate-limited so
      # that bogus kids cannot trigger a download on every request.
      now = time.monotonic()
      if (
          _last_forced_refresh is None
          or now - _last_forced_refresh >= _JWKS_FORCED_REFRESH_INTERVAL
      ):
          _last_forced_refresh = now
          _cached_jwks = None
          key = _find_key(await _get_jwks(), kid)
          if key is not None:
              return key

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail="Signing key not found",
          headers={"WWW-Authenticate": "Bearer"},
      )
  async def verify_token(token: str) -> Dict[str, Any]:
      """Validate a bearer JWT against the Keycloak realm and return its claims.

      The verification key is looked up by the token's ``kid`` in the realm's
      JWKS. ``jose.jwt.decode`` then checks the signature (restricted to
      ``ALGORITHMS``), ``aud`` against ``KEYCLOAK_AUDIENCE``, ``iss`` against
      ``KEYCLOAK_ISSUER``, and the time claims covered by python-jose's defaults.

      Raises:
          HTTPException(401): the token is malformed, expired, has a bad
              signature, wrong audience/issuer, or an unknown signing key.
          HTTPException(503): the JWKS could not be downloaded from Keycloak.
      """
      try:
          key = await _get_key(token)
          return jwt.decode(
              token,
              key,
              algorithms=ALGORITHMS,
              audience=KEYCLOAK_AUDIENCE,
              issuer=KEYCLOAK_ISSUER,
              options={"verify_aud": True, "verify_iss": True},
          )
      except JWTError as e:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=str(e),
              # RFC 6750: a 401 for bearer auth should advertise the scheme.
              headers={"WWW-Authenticate": "Bearer"},
          ) from e
      except httpx.HTTPError as e:
          # Keycloak unreachable or answered non-2xx while fetching the JWKS:
          # this is an upstream outage, not an invalid token, so report 503
          # instead of letting the error surface as an unhandled 500.
          logger.error("Unable to fetch JWKS from %s: %s", KEYCLOAK_JWKS_URL, e)
          raise HTTPException(
              status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
              detail="Authentication service unavailable",
          ) from e
  async def get_current_user(
      credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
  ) -> Dict[str, Any]:
      """FastAPI dependency: authenticate the request and return its JWT claims.

      ``http_bearer`` (``auto_error=True``) extracts the bearer credentials and
      rejects requests that carry none; the raw token string is then validated
      by ``verify_token``, whose claims dict is returned unchanged.
      """
      raw_jwt = credentials.credentials
      claims = await verify_token(raw_jwt)
      return claims
  def require_roles(*roles: str, client_id: Optional[str] = None):
      """Build a FastAPI dependency that requires every role in *roles*.

      Roles are read from the claims returned by ``get_current_user``: realm
      roles from ``realm_access.roles`` and client roles from
      ``resource_access.<client>.roles``. Missing or null entries count as
      "no roles".

      Args:
          *roles: role names that must all be present.
          client_id: if given, only client roles granted for this client (a key
              of ``resource_access``) are considered, in addition to realm
              roles. By default roles of *every* client in the token count, so
              a same-named role on an unrelated client also satisfies the check.

      Returns:
          An async dependency that returns the claims when all roles are present
          and raises ``HTTPException(403)`` listing the missing ones otherwise.
      """
      async def checker(claims: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
          # Realm-level roles; "or" guards against absent or null claims.
          granted = set((claims.get("realm_access") or {}).get("roles") or [])
          # Client-level roles, optionally restricted to a single client.
          resource_access = claims.get("resource_access") or {}
          if client_id is None:
              clients = list(resource_access.values())
          else:
              clients = [resource_access.get(client_id)]
          for client in clients:
              # Null-safe like the realm branch: "roles": null no longer crashes.
              granted.update((client or {}).get("roles") or [])
          missing = [r for r in roles if r not in granted]
          if missing:
              raise HTTPException(
                  status_code=status.HTTP_403_FORBIDDEN,
                  detail=f"Missing roles: {missing}",
              )
          return claims

      return checker