You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

61 lines
2.2 KiB

  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. from starlette.datastructures import URL, Headers
  4. from starlette.responses import PlainTextResponse, RedirectResponse, Response
  5. from starlette.types import ASGIApp, Receive, Scope, Send
  6. ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
  7. class TrustedHostMiddleware:
  8. def __init__(
  9. self,
  10. app: ASGIApp,
  11. allowed_hosts: Sequence[str] | None = None,
  12. www_redirect: bool = True,
  13. ) -> None:
  14. if allowed_hosts is None:
  15. allowed_hosts = ["*"]
  16. for pattern in allowed_hosts:
  17. assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
  18. if pattern.startswith("*") and pattern != "*":
  19. assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
  20. self.app = app
  21. self.allowed_hosts = list(allowed_hosts)
  22. self.allow_any = "*" in allowed_hosts
  23. self.www_redirect = www_redirect
  24. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  25. if self.allow_any or scope["type"] not in (
  26. "http",
  27. "websocket",
  28. ): # pragma: no cover
  29. await self.app(scope, receive, send)
  30. return
  31. headers = Headers(scope=scope)
  32. host = headers.get("host", "").split(":")[0]
  33. is_valid_host = False
  34. found_www_redirect = False
  35. for pattern in self.allowed_hosts:
  36. if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
  37. is_valid_host = True
  38. break
  39. elif "www." + host == pattern:
  40. found_www_redirect = True
  41. if is_valid_host:
  42. await self.app(scope, receive, send)
  43. else:
  44. response: Response
  45. if found_www_redirect and self.www_redirect:
  46. url = URL(scope=scope)
  47. redirect_url = url.replace(netloc="www." + url.netloc)
  48. response = RedirectResponse(url=str(redirect_url))
  49. else:
  50. response = PlainTextResponse("Invalid host header", status_code=400)
  51. await response(scope, receive, send)