|
- import hashlib
- import http.cookies
- import json
- import os
- import stat
- import sys
- import typing
- from email.utils import formatdate
- from functools import partial
- from mimetypes import guess_type as mimetypes_guess_type
- from urllib.parse import quote
-
- import anyio
-
- from starlette.background import BackgroundTask
- from starlette.concurrency import iterate_in_threadpool
- from starlette.datastructures import URL, MutableHeaders
- from starlette.types import Receive, Scope, Send
-
# Register the ``samesite`` cookie attribute with ``http.cookies.Morsel`` so
# that ``Response.set_cookie`` can emit ``SameSite=...``; Pythons older than
# 3.8 do not know this attribute.
http.cookies.Morsel._reserved.update(samesite="SameSite")  # type: ignore
-
-
def guess_type(
    url: typing.Union[str, "os.PathLike[str]"], strict: bool = True
) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Return ``(type, encoding)`` for *url*, like ``mimetypes.guess_type``.

    Before Python 3.8, ``mimetypes.guess_type`` did not accept path-like
    objects, so on those versions the argument is converted with
    ``os.fspath`` first.
    """
    if sys.version_info >= (3, 8):
        return mimetypes_guess_type(url, strict)
    return mimetypes_guess_type(os.fspath(url), strict)  # pragma: no cover
-
-
class Response:
    """HTTP response whose body is rendered up front and sent in one message.

    Subclasses customise behaviour by overriding the ``media_type`` and
    ``charset`` class attributes and/or :meth:`render`.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: typing.Any = None,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        self.status_code = status_code
        if media_type is not None:
            # Instance override of the class-level default.
            self.media_type = media_type
        self.background = background
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes:
        """Convert *content* to bytes.

        ``None`` becomes ``b""``, ``bytes`` pass through unchanged, and
        anything else is ``.encode()``-d with ``self.charset``.
        """
        if content is None:
            return b""
        if isinstance(content, bytes):
            return content
        return content.encode(self.charset)

    def init_headers(
        self, headers: typing.Optional[typing.Mapping[str, str]] = None
    ) -> None:
        """Build ``self.raw_headers`` from *headers* plus computed defaults.

        Header names are lower-cased. Names and values are latin-1 encoded.
        ``content-length`` and ``content-type`` are only added when the caller
        did not supply them.
        """
        if headers is None:
            raw_headers: typing.List[typing.Tuple[bytes, bytes]] = []
            populate_content_length = True
            populate_content_type = True
        else:
            raw_headers = [
                (k.lower().encode("latin-1"), v.encode("latin-1"))
                for k, v in headers.items()
            ]
            keys = {name for name, _ in raw_headers}
            populate_content_length = b"content-length" not in keys
            populate_content_type = b"content-type" not in keys

        # Subclasses that never set ``self.body`` get no content-length here.
        # An empty body still gets ``content-length: 0``.
        body = getattr(self, "body", None)
        if (
            body is not None
            and populate_content_length
            # RFC 7230 3.3.2: no Content-Length on 1xx or 204. On 304 it would
            # describe the (empty) body instead of the selected representation.
            and not (self.status_code < 200 or self.status_code in (204, 304))
        ):
            content_length = str(len(body))
            raw_headers.append((b"content-length", content_length.encode("latin-1")))

        content_type = self.media_type
        if content_type is not None and populate_content_type:
            # Only add a charset parameter if the media type lacks one already.
            if (
                content_type.startswith("text/")
                and "charset=" not in content_type.lower()
            ):
                content_type += "; charset=" + self.charset
            raw_headers.append((b"content-type", content_type.encode("latin-1")))

        self.raw_headers = raw_headers

    @property
    def headers(self) -> MutableHeaders:
        """A ``MutableHeaders`` built from ``raw_headers`` on first access, then cached."""
        if not hasattr(self, "_headers"):
            self._headers = MutableHeaders(raw=self.raw_headers)
        return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: typing.Optional[int] = None,
        expires: typing.Optional[int] = None,
        path: typing.Optional[str] = "/",
        domain: typing.Optional[str] = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Optional[str] = "lax",
    ) -> None:
        """Append a ``set-cookie`` header to ``raw_headers``.

        Attributes left as ``None`` (or ``False`` for the flags) are omitted.
        An int ``expires`` is rendered by ``http.cookies`` as a date that many
        seconds from now.

        Raises:
            AssertionError: if *samesite* is not ``strict``, ``lax`` or
                ``none`` (case-insensitive). This check only runs when
                assertions are enabled.
        """
        cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie()
        cookie[key] = value
        if max_age is not None:
            cookie[key]["max-age"] = max_age
        if expires is not None:
            cookie[key]["expires"] = expires
        if path is not None:
            cookie[key]["path"] = path
        if domain is not None:
            cookie[key]["domain"] = domain
        if secure:
            cookie[key]["secure"] = True
        if httponly:
            cookie[key]["httponly"] = True
        if samesite is not None:
            assert samesite.lower() in [
                "strict",
                "lax",
                "none",
            ], "samesite must be either 'strict', 'lax' or 'none'"
            cookie[key]["samesite"] = samesite
        # output() prefixes "Set-Cookie:"-style text; header="" strips it.
        cookie_val = cookie.output(header="").strip()
        self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: typing.Optional[str] = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Optional[str] = "lax",
    ) -> None:
        """Expire cookie *key* by re-setting it with ``max_age=0`` and ``expires=0``."""
        self.set_cookie(
            key,
            max_age=0,
            expires=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the start message and the whole body, then run any background task."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        await send({"type": "http.response.body", "body": self.body})

        if self.background is not None:
            await self.background()
-
-
class HTMLResponse(Response):
    """``Response`` with a ``text/html`` content type; the charset is appended by the base class."""

    media_type = "text/html"
-
-
class PlainTextResponse(Response):
    """``Response`` with a ``text/plain`` content type; the charset is appended by the base class."""

    media_type = "text/plain"
-
-
class JSONResponse(Response):
    """``Response`` that serialises its content as compact UTF-8 JSON."""

    media_type = "application/json"

    def render(self, content: typing.Any) -> bytes:
        """Serialise *content* with no whitespace between tokens.

        Non-ASCII characters are written literally rather than escaped.
        ``allow_nan=False`` makes ``json.dumps`` reject NaN and infinities,
        which are not valid JSON.
        """
        serialised = json.dumps(
            content,
            separators=(",", ":"),
            indent=None,
            allow_nan=False,
            ensure_ascii=False,
        )
        return serialised.encode("utf-8")
-
-
class RedirectResponse(Response):
    """Empty-bodied response whose ``location`` header points at *url*.

    The status code defaults to 307.
    """

    def __init__(
        self,
        url: typing.Union[str, URL],
        status_code: int = 307,
        headers: dict = None,
        background: BackgroundTask = None,
    ) -> None:
        super().__init__(
            content=b"",
            status_code=status_code,
            headers=headers,
            background=background,
        )
        # Percent-encode anything outside the listed safe characters. Existing
        # "%" escapes and URL delimiters are kept as they are.
        target = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
        self.headers["location"] = target
-
-
class StreamingResponse(Response):
    """Response whose body is produced chunk by chunk from an iterator.

    *content* may be an async iterable, which is consumed directly, or any
    other iterable, which is wrapped with ``iterate_in_threadpool``. Chunks
    that are not ``bytes`` are encoded with ``self.charset``.
    """

    def __init__(
        self,
        content: typing.Any,
        status_code: int = 200,
        headers: dict = None,
        media_type: str = None,
        background: BackgroundTask = None,
    ) -> None:
        # Response.__init__ is intentionally not called: there is no
        # pre-rendered body, so ``self.body`` is never set.
        self.body_iterator = (
            content
            if isinstance(content, typing.AsyncIterable)
            else iterate_in_threadpool(content)
        )
        self.status_code = status_code
        self.media_type = media_type if media_type is not None else self.media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Keep receiving messages until one of type ``http.disconnect`` arrives."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()

    async def stream_response(self, send: Send) -> None:
        """Send the start message, every body chunk, then an empty final chunk."""
        start_message = {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
        await send(start_message)
        async for piece in self.body_iterator:
            payload = piece if isinstance(piece, bytes) else piece.encode(self.charset)
            await send({"type": "http.response.body", "body": payload, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the body while watching for a disconnect, then run any background task."""
        async with anyio.create_task_group() as task_group:

            async def run_then_cancel(func: typing.Callable[[], typing.Coroutine]) -> None:
                # Whichever side finishes first (all chunks sent, or the client
                # disconnected) cancels the task group, stopping the other.
                await func()
                task_group.cancel_scope.cancel()

            task_group.start_soon(run_then_cancel, partial(self.stream_response, send))
            await run_then_cancel(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
-
-
class FileResponse(Response):
    """Response that sends a file from disk in ``chunk_size``-byte pieces.

    ``content-length``, ``last-modified`` and ``etag`` come from the file's
    ``os.stat`` result. They are set immediately when *stat_result* is
    passed, otherwise when the response is sent. When *method* is ``HEAD``,
    only the headers are sent.
    """

    chunk_size = 4096

    def __init__(
        self,
        path: typing.Union[str, "os.PathLike[str]"],
        status_code: int = 200,
        headers: dict = None,
        media_type: str = None,
        background: BackgroundTask = None,
        filename: str = None,
        stat_result: os.stat_result = None,
        method: str = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        """Create the response.

        Args:
            filename: if given, adds a ``content-disposition`` header (unless
                one was supplied in *headers*) and is used for MIME guessing.
            content_disposition_type: disposition type for that header, e.g.
                ``"attachment"`` (default) or ``"inline"``.
        """
        self.path = path
        self.status_code = status_code
        self.filename = filename
        self.send_header_only = method is not None and method.upper() == "HEAD"
        if media_type is None:
            # Unknown extensions fall back to text/plain.
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        if self.filename is not None:
            # A name that quote() would alter cannot go in the plain quoted
            # form, so it is sent percent-encoded via the ``filename*=utf-8''``
            # extended syntax.
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                content_disposition = (
                    f"{content_disposition_type}; "
                    f"filename*=utf-8''{content_disposition_filename}"
                )
            else:
                content_disposition = (
                    f'{content_disposition_type}; filename="{self.filename}"'
                )
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Default the content-length, last-modified and etag headers from *stat_result*.

        Values already present for these headers are left unchanged.
        """
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        # RFC 7232 2.3: an entity-tag is an opaque string in double quotes.
        # MD5 is only a cheap fingerprint of mtime and size here; it is not
        # used for security.
        etag = '"{}"'.format(hashlib.md5(etag_base.encode()).hexdigest())

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the headers and (unless HEAD) the file body, then run any background task.

        Raises:
            RuntimeError: if no *stat_result* was supplied and ``path`` does
                not exist or is not a regular file.
        """
        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
            except FileNotFoundError:
                raise RuntimeError(f"File at path {self.path} does not exist.")
            self.set_stat_headers(stat_result)
            # Only regular files are served; directories and other types are rejected.
            if not stat.S_ISREG(stat_result.st_mode):
                raise RuntimeError(f"File at path {self.path} is not a file.")
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        if self.send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short read marks the end of the file. A file whose size
                    # is an exact multiple of chunk_size ends with an empty
                    # final chunk.
                    more_body = len(chunk) == self.chunk_size
                    await send(
                        {
                            "type": "http.response.body",
                            "body": chunk,
                            "more_body": more_body,
                        }
                    )
        if self.background is not None:
            await self.background()
|