|
- import typing
- from os import PathLike
-
- from starlette.background import BackgroundTask
- from starlette.responses import Response
- from starlette.types import Receive, Scope, Send
-
# Jinja 3.0 introduced ``@pass_context`` as the new name for the
# ``@contextfunction`` decorator, and the old name was slated for removal in
# Jinja 3.1. Bind whichever of the two this Jinja installation provides.
# When Jinja is not installed at all, ``jinja2`` is set to ``None`` and
# ``pass_context`` is left undefined.
try:
    import jinja2
except ImportError:  # pragma: nocover
    jinja2 = None  # type: ignore
else:
    if hasattr(jinja2, "pass_context"):
        pass_context = jinja2.pass_context
    else:  # pragma: nocover
        pass_context = jinja2.contextfunction
-
-
- class _TemplateResponse(Response):
- media_type = "text/html"
-
- def __init__(
- self,
- template: typing.Any,
- context: dict,
- status_code: int = 200,
- headers: dict = None,
- media_type: str = None,
- background: BackgroundTask = None,
- ):
- self.template = template
- self.context = context
- content = template.render(context)
- super().__init__(content, status_code, headers, media_type, background)
-
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- request = self.context.get("request", {})
- extensions = request.get("extensions", {})
- if "http.response.template" in extensions:
- await send(
- {
- "type": "http.response.template",
- "template": self.template,
- "context": self.context,
- }
- )
- await super().__call__(scope, receive, send)
-
-
- class Jinja2Templates:
- """
- templates = Jinja2Templates("templates")
-
- return templates.TemplateResponse("index.html", {"request": request})
- """
-
- def __init__(self, directory: typing.Union[str, PathLike]) -> None:
- assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
- self.env = self._create_env(directory)
-
- def _create_env(
- self, directory: typing.Union[str, PathLike]
- ) -> "jinja2.Environment":
- @pass_context
- def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
- request = context["request"]
- return request.url_for(name, **path_params)
-
- loader = jinja2.FileSystemLoader(directory)
- env = jinja2.Environment(loader=loader, autoescape=True)
- env.globals["url_for"] = url_for
- return env
-
- def get_template(self, name: str) -> "jinja2.Template":
- return self.env.get_template(name)
-
- def TemplateResponse(
- self,
- name: str,
- context: dict,
- status_code: int = 200,
- headers: dict = None,
- media_type: str = None,
- background: BackgroundTask = None,
- ) -> _TemplateResponse:
- if "request" not in context:
- raise ValueError('context must include a "request" key')
- template = self.get_template(name)
- return _TemplateResponse(
- template,
- context,
- status_code=status_code,
- headers=headers,
- media_type=media_type,
- background=background,
- )
|