You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

98 rivejä
3.0 KiB

  1. import typing
  2. from os import PathLike
  3. from starlette.background import BackgroundTask
  4. from starlette.responses import Response
  5. from starlette.types import Receive, Scope, Send
  6. try:
  7. import jinja2
  8. # @contextfunction renamed to @pass_context in Jinja 3.0, to be removed in 3.1
  9. if hasattr(jinja2, "pass_context"):
  10. pass_context = jinja2.pass_context
  11. else: # pragma: nocover
  12. pass_context = jinja2.contextfunction
  13. except ImportError: # pragma: nocover
  14. jinja2 = None # type: ignore
  15. class _TemplateResponse(Response):
  16. media_type = "text/html"
  17. def __init__(
  18. self,
  19. template: typing.Any,
  20. context: dict,
  21. status_code: int = 200,
  22. headers: dict = None,
  23. media_type: str = None,
  24. background: BackgroundTask = None,
  25. ):
  26. self.template = template
  27. self.context = context
  28. content = template.render(context)
  29. super().__init__(content, status_code, headers, media_type, background)
  30. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  31. request = self.context.get("request", {})
  32. extensions = request.get("extensions", {})
  33. if "http.response.template" in extensions:
  34. await send(
  35. {
  36. "type": "http.response.template",
  37. "template": self.template,
  38. "context": self.context,
  39. }
  40. )
  41. await super().__call__(scope, receive, send)
  42. class Jinja2Templates:
  43. """
  44. templates = Jinja2Templates("templates")
  45. return templates.TemplateResponse("index.html", {"request": request})
  46. """
  47. def __init__(self, directory: typing.Union[str, PathLike]) -> None:
  48. assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
  49. self.env = self._create_env(directory)
  50. def _create_env(
  51. self, directory: typing.Union[str, PathLike]
  52. ) -> "jinja2.Environment":
  53. @pass_context
  54. def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
  55. request = context["request"]
  56. return request.url_for(name, **path_params)
  57. loader = jinja2.FileSystemLoader(directory)
  58. env = jinja2.Environment(loader=loader, autoescape=True)
  59. env.globals["url_for"] = url_for
  60. return env
  61. def get_template(self, name: str) -> "jinja2.Template":
  62. return self.env.get_template(name)
  63. def TemplateResponse(
  64. self,
  65. name: str,
  66. context: dict,
  67. status_code: int = 200,
  68. headers: dict = None,
  69. media_type: str = None,
  70. background: BackgroundTask = None,
  71. ) -> _TemplateResponse:
  72. if "request" not in context:
  73. raise ValueError('context must include a "request" key')
  74. template = self.get_template(name)
  75. return _TemplateResponse(
  76. template,
  77. context,
  78. status_code=status_code,
  79. headers=headers,
  80. media_type=media_type,
  81. background=background,
  82. )