You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

61 lines
2.1 KiB

  1. import typing
  2. from starlette.datastructures import URL, Headers
  3. from starlette.responses import PlainTextResponse, RedirectResponse, Response
  4. from starlette.types import ASGIApp, Receive, Scope, Send
  5. ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
  6. class TrustedHostMiddleware:
  7. def __init__(
  8. self,
  9. app: ASGIApp,
  10. allowed_hosts: typing.Sequence[str] = None,
  11. www_redirect: bool = True,
  12. ) -> None:
  13. if allowed_hosts is None:
  14. allowed_hosts = ["*"]
  15. for pattern in allowed_hosts:
  16. assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
  17. if pattern.startswith("*") and pattern != "*":
  18. assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
  19. self.app = app
  20. self.allowed_hosts = list(allowed_hosts)
  21. self.allow_any = "*" in allowed_hosts
  22. self.www_redirect = www_redirect
  23. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  24. if self.allow_any or scope["type"] not in (
  25. "http",
  26. "websocket",
  27. ): # pragma: no cover
  28. await self.app(scope, receive, send)
  29. return
  30. headers = Headers(scope=scope)
  31. host = headers.get("host", "").split(":")[0]
  32. is_valid_host = False
  33. found_www_redirect = False
  34. for pattern in self.allowed_hosts:
  35. if host == pattern or (
  36. pattern.startswith("*") and host.endswith(pattern[1:])
  37. ):
  38. is_valid_host = True
  39. break
  40. elif "www." + host == pattern:
  41. found_www_redirect = True
  42. if is_valid_host:
  43. await self.app(scope, receive, send)
  44. else:
  45. response: Response
  46. if found_www_redirect and self.www_redirect:
  47. url = URL(scope=scope)
  48. redirect_url = url.replace(netloc="www." + url.netloc)
  49. response = RedirectResponse(url=str(redirect_url))
  50. else:
  51. response = PlainTextResponse("Invalid host header", status_code=400)
  52. await response(scope, receive, send)