|
- from __future__ import annotations
-
- import io
- import math
- import sys
- import warnings
- from collections.abc import MutableMapping
- from typing import Any, Callable
-
- import anyio
- from anyio.abc import ObjectReceiveStream, ObjectSendStream
-
- from starlette.types import Receive, Scope, Send
-
# This statement sits at module top level, so the deprecation notice is issued
# when the module body executes (i.e. on import).
_DEPRECATION_NOTICE = (
    "starlette.middleware.wsgi is deprecated and will be removed in a future release. "
    "Please refer to https://github.com/abersheeran/a2wsgi as a replacement."
)
warnings.warn(_DEPRECATION_NOTICE, DeprecationWarning)
-
-
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
    """
    Builds a scope and request body into a WSGI environ object.

    Args:
        scope: The ASGI HTTP scope. ``method``, ``path``, ``query_string`` and
            ``http_version`` are read unconditionally; ``root_path``,
            ``scheme``, ``server``, ``client`` and ``headers`` are optional.
        body: The complete request body; it is exposed through ``wsgi.input``.

    Returns:
        A new environ dict. Every CGI-style variable is a ``str``, including
        ``SERVER_PORT`` as PEP 3333 requires.

    Raises:
        KeyError: If one of the unconditionally-read scope keys is missing.
        UnicodeDecodeError: If ``query_string`` contains non-ASCII bytes.
    """

    # WSGI strings are "bytes-as-latin1": re-encode the unicode path as UTF-8
    # and expose those bytes through a latin-1 decode.
    script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
    path_info = scope["path"].encode("utf8").decode("latin1")
    if path_info.startswith(script_name):
        path_info = path_info[len(script_name) :]

    environ: dict[str, Any] = {
        "REQUEST_METHOD": scope["method"],
        "SCRIPT_NAME": script_name,
        "PATH_INFO": path_info,
        "QUERY_STRING": scope["query_string"].decode("ascii"),
        "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
        "wsgi.version": (1, 0),
        "wsgi.url_scheme": scope.get("scheme", "http"),
        "wsgi.input": io.BytesIO(body),
        "wsgi.errors": sys.stdout,
        "wsgi.multithread": True,
        "wsgi.multiprocess": True,
        "wsgi.run_once": False,
    }

    # Get server name and port - required in WSGI, not in ASGI
    server = scope.get("server") or ("localhost", 80)
    environ["SERVER_NAME"] = server[0]
    # CGI variables must be strings. The port can also be None (ASGI uses that
    # for unix-socket servers), in which case "0" is reported.
    environ["SERVER_PORT"] = str(server[1] or 0)

    # Get client IP address
    if scope.get("client"):
        environ["REMOTE_ADDR"] = scope["client"][0]

    # Go through headers and make them into environ entries
    for name, value in scope.get("headers", []):
        name = name.decode("latin1")
        if name == "content-length":
            corrected_name = "CONTENT_LENGTH"
        elif name == "content-type":
            corrected_name = "CONTENT_TYPE"
        else:
            corrected_name = f"HTTP_{name}".upper().replace("-", "_")
        # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
        # case
        value = value.decode("latin1")
        # Repeated headers are folded into one comma-separated value.
        if corrected_name in environ:
            value = environ[corrected_name] + "," + value
        environ[corrected_name] = value
    return environ
-
-
class WSGIMiddleware:
    """ASGI callable that hands each HTTP request to a wrapped WSGI application."""

    def __init__(self, app: Callable[..., Any]) -> None:
        # The wrapped application; it is passed to a new WSGIResponder per request.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Serve one request by delegating to a fresh ``WSGIResponder``.

        Only scopes whose ``type`` is ``"http"`` are accepted; anything else
        trips the assertion below.
        """
        assert scope["type"] == "http"
        await WSGIResponder(self.app, scope)(receive, send)
-
-
class WSGIResponder:
    """
    Runs a single WSGI application call and relays its response as ASGI messages.

    The WSGI callable is executed in a worker thread. Everything it produces is
    pushed onto an in-memory stream, which the ``sender`` task drains into the
    ASGI ``send`` callable.
    """

    stream_send: ObjectSendStream[MutableMapping[str, Any]]
    stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]

    def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
        self.app = app
        self.scope = scope
        self.status = None
        self.response_headers = None
        # The buffer size is math.inf, i.e. the stream has no capacity limit.
        self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
        self.response_started = False
        self.exc_info: Any = None

    async def __call__(self, receive: Receive, send: Send) -> None:
        """
        Read the whole request body, run the WSGI app and stream its response.

        Raises:
            BaseException: Whatever exception the application handed to
                ``start_response`` through ``exc_info`` is re-raised once the
                application call has finished.
        """
        # Collect the chunks and join once instead of repeated bytes
        # concatenation, which is quadratic in the number of messages.
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            message = await receive()
            chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)
        environ = build_environ(self.scope, b"".join(chunks))

        async with anyio.create_task_group() as task_group:
            task_group.start_soon(self.sender, send)
            # Leaving this context closes the send side of the stream, which
            # ends the ``async for`` loop in ``sender``.
            async with self.stream_send:
                await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
        if self.exc_info is not None:
            # exc_info is (type, value, traceback): calling with_traceback
            # through the type, with the value as its first argument, yields
            # the original exception instance carrying that traceback.
            raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])

    async def sender(self, send: Send) -> None:
        """Forward every queued message to the ASGI ``send`` callable."""
        async with self.stream_receive:
            async for message in self.stream_receive:
                await send(message)

    def start_response(
        self,
        status: str,
        response_headers: list[tuple[str, str]],
        exc_info: Any = None,
    ) -> None:
        """
        WSGI ``start_response`` callable; called from the worker thread.

        Args:
            status: Status line such as ``"200 OK"``. Only the leading numeric
                code is used, so a missing reason phrase is tolerated.
            response_headers: ``(name, value)`` string pairs. Names are
                lower-cased and both parts are ASCII-encoded.
            exc_info: Optional ``sys.exc_info()`` triple. It is stored and
                re-raised by ``__call__`` after the application returns.

        Raises:
            ValueError: If the status code is not an integer.
            UnicodeEncodeError: If a header name or value is not ASCII.
        """
        self.exc_info = exc_info
        if not self.response_started:  # pragma: no branch
            self.response_started = True
            # Take only the code so that a bare "200" no longer fails on
            # tuple unpacking.
            status_code = int(status.split(" ", 1)[0])
            headers = [
                (name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
                for name, value in response_headers
            ]
            anyio.from_thread.run(
                self.stream_send.send,
                {
                    "type": "http.response.start",
                    "status": status_code,
                    "headers": headers,
                },
            )

    def wsgi(
        self,
        environ: dict[str, Any],
        start_response: Callable[..., Any],
    ) -> None:
        """
        Call the WSGI application and queue its body chunks; runs in a worker thread.

        Each chunk is queued with ``more_body=True``; an empty final body
        message is queued after iteration completes. The iterable returned by
        the application is closed when it exposes ``close()``, whether
        iteration finished or failed.
        """
        iterable = self.app(environ, start_response)
        try:
            for chunk in iterable:
                anyio.from_thread.run(
                    self.stream_send.send,
                    {"type": "http.response.body", "body": chunk, "more_body": True},
                )
        finally:
            # PEP 3333: the gateway must call close() on the response iterable
            # if it has one, so the application can release its resources.
            close = getattr(iterable, "close", None)
            if close is not None:
                close()

        anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|