|
- import typing
-
- from starlette.datastructures import URL, Headers
- from starlette.responses import PlainTextResponse, RedirectResponse, Response
- from starlette.types import ASGIApp, Receive, Scope, Send
-
# Assertion message used when an allowed-host pattern has a misplaced "*".
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."


class TrustedHostMiddleware:
    """ASGI middleware that only forwards requests with an allowed ``Host``.

    Each entry of ``allowed_hosts`` is one of:

    * an exact host name, compared for equality with the request host;
    * ``"*"``, which disables the check entirely;
    * a ``"*.example.com"`` style wildcard, which matches any host ending in
      ``".example.com"`` (so it does not match the bare ``example.com``).

    Args:
        app: The wrapped ASGI application.
        allowed_hosts: Host patterns to accept. ``None`` is treated as
            ``["*"]`` (accept everything).
        www_redirect: When true, a request whose host is not allowed but whose
            ``"www."``-prefixed form is listed in ``allowed_hosts`` receives a
            redirect to that ``www.`` host instead of an error.

    Raises:
        AssertionError: If a pattern contains ``"*"`` anywhere other than as a
            leading ``"*."`` prefix (or the lone ``"*"``).
    """

    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
        www_redirect: bool = True,
    ) -> None:
        if allowed_hosts is None:
            allowed_hosts = ["*"]

        # NOTE: validation is assert-based, so it is skipped under ``python -O``.
        # It is kept as-is because callers may rely on AssertionError.
        for pattern in allowed_hosts:
            assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
            if pattern.startswith("*") and pattern != "*":
                assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
        self.app = app
        self.allowed_hosts = list(allowed_hosts)
        self.allow_any = "*" in allowed_hosts
        self.www_redirect = www_redirect

    @staticmethod
    def _hostname_from_header(value: str) -> str:
        """Return the host part of a ``Host`` header value, dropping any port.

        A bracketed IPv6 literal (``"[::1]"`` or ``"[::1]:8000"``) is returned
        with its brackets. Splitting such a value on the first ``":"`` would
        yield ``"["``, which can never equal a configured pattern. Any other
        value, including a malformed bracketed one, is cut at the first ``":"``.
        """
        if value.startswith("["):
            literal, bracket, rest = value.partition("]")
            # Only accept the literal when nothing but an optional ":port"
            # follows the closing bracket.
            if bracket and (not rest or rest.startswith(":")):
                return literal + bracket
        return value.split(":")[0]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the request if its host is allowed, otherwise respond here.

        Scopes that are neither ``"http"`` nor ``"websocket"`` are passed to
        the wrapped app unchanged, as is everything when ``"*"`` is allowed.
        A disallowed host gets either a redirect to its ``www.`` form (see
        ``www_redirect``) or a 400 plain-text response.
        """
        if self.allow_any or scope["type"] not in (
            "http",
            "websocket",
        ):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        headers = Headers(scope=scope)
        host = self._hostname_from_header(headers.get("host", ""))
        is_valid_host = False
        found_www_redirect = False
        for pattern in self.allowed_hosts:
            if host == pattern or (
                pattern.startswith("*") and host.endswith(pattern[1:])
            ):
                is_valid_host = True
                break
            elif "www." + host == pattern:
                # Remember the candidate, but keep scanning: a later pattern
                # may still allow the host outright.
                found_www_redirect = True

        if is_valid_host:
            await self.app(scope, receive, send)
            return

        response: Response
        if found_www_redirect and self.www_redirect:
            url = URL(scope=scope)
            # Prefix the whole netloc so the rest of the URL is preserved.
            redirect_url = url.replace(netloc="www." + url.netloc)
            response = RedirectResponse(url=str(redirect_url))
        else:
            response = PlainTextResponse("Invalid host header", status_code=400)
        await response(scope, receive, send)
|