You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

218 lines
8.1 KiB

  1. from __future__ import annotations
  2. import warnings
  3. from collections.abc import Mapping, Sequence
  4. from os import PathLike
  5. from typing import Any, Callable, cast, overload
  6. from starlette.background import BackgroundTask
  7. from starlette.datastructures import URL
  8. from starlette.requests import Request
  9. from starlette.responses import HTMLResponse
  10. from starlette.types import Receive, Scope, Send
  11. try:
  12. import jinja2
  13. # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
  14. # hence we try to get pass_context (most installs will be >=3.1)
  15. # and fall back to contextfunction,
  16. # adding a type ignore for mypy to let us access an attribute that may not exist
  17. if hasattr(jinja2, "pass_context"):
  18. pass_context = jinja2.pass_context
  19. else: # pragma: no cover
  20. pass_context = jinja2.contextfunction # type: ignore[attr-defined]
  21. except ModuleNotFoundError: # pragma: no cover
  22. jinja2 = None # type: ignore[assignment]
  23. class _TemplateResponse(HTMLResponse):
  24. def __init__(
  25. self,
  26. template: Any,
  27. context: dict[str, Any],
  28. status_code: int = 200,
  29. headers: Mapping[str, str] | None = None,
  30. media_type: str | None = None,
  31. background: BackgroundTask | None = None,
  32. ):
  33. self.template = template
  34. self.context = context
  35. content = template.render(context)
  36. super().__init__(content, status_code, headers, media_type, background)
  37. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  38. request = self.context.get("request", {})
  39. extensions = request.get("extensions", {})
  40. if "http.response.debug" in extensions: # pragma: no branch
  41. await send(
  42. {
  43. "type": "http.response.debug",
  44. "info": {
  45. "template": self.template,
  46. "context": self.context,
  47. },
  48. }
  49. )
  50. await super().__call__(scope, receive, send)
  51. class Jinja2Templates:
  52. """
  53. templates = Jinja2Templates("templates")
  54. return templates.TemplateResponse("index.html", {"request": request})
  55. """
  56. @overload
  57. def __init__(
  58. self,
  59. directory: str | PathLike[str] | Sequence[str | PathLike[str]],
  60. *,
  61. context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
  62. **env_options: Any,
  63. ) -> None: ...
  64. @overload
  65. def __init__(
  66. self,
  67. *,
  68. env: jinja2.Environment,
  69. context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
  70. ) -> None: ...
  71. def __init__(
  72. self,
  73. directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None,
  74. *,
  75. context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
  76. env: jinja2.Environment | None = None,
  77. **env_options: Any,
  78. ) -> None:
  79. if env_options:
  80. warnings.warn(
  81. "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
  82. DeprecationWarning,
  83. )
  84. assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
  85. assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
  86. self.context_processors = context_processors or []
  87. if directory is not None:
  88. self.env = self._create_env(directory, **env_options)
  89. elif env is not None: # pragma: no branch
  90. self.env = env
  91. self._setup_env_defaults(self.env)
  92. def _create_env(
  93. self,
  94. directory: str | PathLike[str] | Sequence[str | PathLike[str]],
  95. **env_options: Any,
  96. ) -> jinja2.Environment:
  97. loader = jinja2.FileSystemLoader(directory)
  98. env_options.setdefault("loader", loader)
  99. env_options.setdefault("autoescape", True)
  100. return jinja2.Environment(**env_options)
  101. def _setup_env_defaults(self, env: jinja2.Environment) -> None:
  102. @pass_context
  103. def url_for(
  104. context: dict[str, Any],
  105. name: str,
  106. /,
  107. **path_params: Any,
  108. ) -> URL:
  109. request: Request = context["request"]
  110. return request.url_for(name, **path_params)
  111. env.globals.setdefault("url_for", url_for)
  112. def get_template(self, name: str) -> jinja2.Template:
  113. return self.env.get_template(name)
  114. @overload
  115. def TemplateResponse(
  116. self,
  117. request: Request,
  118. name: str,
  119. context: dict[str, Any] | None = None,
  120. status_code: int = 200,
  121. headers: Mapping[str, str] | None = None,
  122. media_type: str | None = None,
  123. background: BackgroundTask | None = None,
  124. ) -> _TemplateResponse: ...
  125. @overload
  126. def TemplateResponse(
  127. self,
  128. name: str,
  129. context: dict[str, Any] | None = None,
  130. status_code: int = 200,
  131. headers: Mapping[str, str] | None = None,
  132. media_type: str | None = None,
  133. background: BackgroundTask | None = None,
  134. ) -> _TemplateResponse:
  135. # Deprecated usage
  136. ...
  137. def TemplateResponse(self, *args: Any, **kwargs: Any) -> _TemplateResponse:
  138. if args:
  139. if isinstance(args[0], str): # the first argument is template name (old style)
  140. warnings.warn(
  141. "The `name` is not the first parameter anymore. "
  142. "The first parameter should be the `Request` instance.\n"
  143. 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
  144. DeprecationWarning,
  145. )
  146. name = args[0]
  147. context = args[1] if len(args) > 1 else kwargs.get("context", {})
  148. status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
  149. headers = args[3] if len(args) > 3 else kwargs.get("headers")
  150. media_type = args[4] if len(args) > 4 else kwargs.get("media_type")
  151. background = args[5] if len(args) > 5 else kwargs.get("background")
  152. if "request" not in context:
  153. raise ValueError('context must include a "request" key')
  154. request = context["request"]
  155. else: # the first argument is a request instance (new style)
  156. request = args[0]
  157. name = args[1] if len(args) > 1 else kwargs["name"]
  158. context = args[2] if len(args) > 2 else kwargs.get("context", {})
  159. status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
  160. headers = args[4] if len(args) > 4 else kwargs.get("headers")
  161. media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
  162. background = args[6] if len(args) > 6 else kwargs.get("background")
  163. else: # all arguments are kwargs
  164. if "request" not in kwargs:
  165. warnings.warn(
  166. "The `TemplateResponse` now requires the `request` argument.\n"
  167. 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
  168. DeprecationWarning,
  169. )
  170. if "request" not in kwargs.get("context", {}):
  171. raise ValueError('context must include a "request" key')
  172. context = kwargs.get("context", {})
  173. request = kwargs.get("request", context.get("request"))
  174. name = cast(str, kwargs["name"])
  175. status_code = kwargs.get("status_code", 200)
  176. headers = kwargs.get("headers")
  177. media_type = kwargs.get("media_type")
  178. background = kwargs.get("background")
  179. context.setdefault("request", request)
  180. for context_processor in self.context_processors:
  181. context.update(context_processor(request))
  182. template = self.get_template(name)
  183. return _TemplateResponse(
  184. template,
  185. context,
  186. status_code=status_code,
  187. headers=headers,
  188. media_type=media_type,
  189. background=background,
  190. )