Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

154 řádky
5.2 KiB

  1. from __future__ import annotations
  2. import io
  3. import math
  4. import sys
  5. import warnings
  6. from collections.abc import MutableMapping
  7. from typing import Any, Callable
  8. import anyio
  9. from anyio.abc import ObjectReceiveStream, ObjectSendStream
  10. from starlette.types import Receive, Scope, Send
  11. warnings.warn(
  12. "starlette.middleware.wsgi is deprecated and will be removed in a future release. "
  13. "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
  14. DeprecationWarning,
  15. )
  16. def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
  17. """
  18. Builds a scope and request body into a WSGI environ object.
  19. """
  20. script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
  21. path_info = scope["path"].encode("utf8").decode("latin1")
  22. if path_info.startswith(script_name):
  23. path_info = path_info[len(script_name) :]
  24. environ = {
  25. "REQUEST_METHOD": scope["method"],
  26. "SCRIPT_NAME": script_name,
  27. "PATH_INFO": path_info,
  28. "QUERY_STRING": scope["query_string"].decode("ascii"),
  29. "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
  30. "wsgi.version": (1, 0),
  31. "wsgi.url_scheme": scope.get("scheme", "http"),
  32. "wsgi.input": io.BytesIO(body),
  33. "wsgi.errors": sys.stdout,
  34. "wsgi.multithread": True,
  35. "wsgi.multiprocess": True,
  36. "wsgi.run_once": False,
  37. }
  38. # Get server name and port - required in WSGI, not in ASGI
  39. server = scope.get("server") or ("localhost", 80)
  40. environ["SERVER_NAME"] = server[0]
  41. environ["SERVER_PORT"] = server[1]
  42. # Get client IP address
  43. if scope.get("client"):
  44. environ["REMOTE_ADDR"] = scope["client"][0]
  45. # Go through headers and make them into environ entries
  46. for name, value in scope.get("headers", []):
  47. name = name.decode("latin1")
  48. if name == "content-length":
  49. corrected_name = "CONTENT_LENGTH"
  50. elif name == "content-type":
  51. corrected_name = "CONTENT_TYPE"
  52. else:
  53. corrected_name = f"HTTP_{name}".upper().replace("-", "_")
  54. # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
  55. # case
  56. value = value.decode("latin1")
  57. if corrected_name in environ:
  58. value = environ[corrected_name] + "," + value
  59. environ[corrected_name] = value
  60. return environ
  61. class WSGIMiddleware:
  62. def __init__(self, app: Callable[..., Any]) -> None:
  63. self.app = app
  64. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  65. assert scope["type"] == "http"
  66. responder = WSGIResponder(self.app, scope)
  67. await responder(receive, send)
  68. class WSGIResponder:
  69. stream_send: ObjectSendStream[MutableMapping[str, Any]]
  70. stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
  71. def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
  72. self.app = app
  73. self.scope = scope
  74. self.status = None
  75. self.response_headers = None
  76. self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
  77. self.response_started = False
  78. self.exc_info: Any = None
  79. async def __call__(self, receive: Receive, send: Send) -> None:
  80. body = b""
  81. more_body = True
  82. while more_body:
  83. message = await receive()
  84. body += message.get("body", b"")
  85. more_body = message.get("more_body", False)
  86. environ = build_environ(self.scope, body)
  87. async with anyio.create_task_group() as task_group:
  88. task_group.start_soon(self.sender, send)
  89. async with self.stream_send:
  90. await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
  91. if self.exc_info is not None:
  92. raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
  93. async def sender(self, send: Send) -> None:
  94. async with self.stream_receive:
  95. async for message in self.stream_receive:
  96. await send(message)
  97. def start_response(
  98. self,
  99. status: str,
  100. response_headers: list[tuple[str, str]],
  101. exc_info: Any = None,
  102. ) -> None:
  103. self.exc_info = exc_info
  104. if not self.response_started: # pragma: no branch
  105. self.response_started = True
  106. status_code_string, _ = status.split(" ", 1)
  107. status_code = int(status_code_string)
  108. headers = [
  109. (name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
  110. for name, value in response_headers
  111. ]
  112. anyio.from_thread.run(
  113. self.stream_send.send,
  114. {
  115. "type": "http.response.start",
  116. "status": status_code,
  117. "headers": headers,
  118. },
  119. )
  120. def wsgi(
  121. self,
  122. environ: dict[str, Any],
  123. start_response: Callable[..., Any],
  124. ) -> None:
  125. for chunk in self.app(environ, start_response):
  126. anyio.from_thread.run(
  127. self.stream_send.send,
  128. {"type": "http.response.body", "body": chunk, "more_body": True},
  129. )
  130. anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})