|
- from __future__ import annotations
-
- from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
- from shlex import shlex
- from typing import (
- Any,
- BinaryIO,
- NamedTuple,
- TypeVar,
- Union,
- cast,
- )
- from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
-
- from starlette.concurrency import run_in_threadpool
- from starlette.types import Scope
-
-
class Address(NamedTuple):
    """A network endpoint expressed as a ``(host, port)`` pair."""

    # Host part of the endpoint.
    host: str
    # Port number of the endpoint.
    port: int
-
-
# Type variable for mapping keys.
_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant, but mapping values are covariant because a
# `Mapping` only allows values to be read, never assigned;
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`.
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
-
-
class URL:
    """
    A URL held as a string and split into its components on demand.

    An instance is built from exactly one of: a URL string, a connection
    `scope` mapping, or keyword `components` understood by `replace()`.
    The modifying methods never mutate the instance; they return a new one.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        """
        Build the URL.

        Args:
            url: The URL as a string.
            scope: A mapping providing "path" and "headers", and optionally
                "scheme" (default "http"), "server" and "query_string".
            **components: Keyword arguments forwarded to `replace()` on an
                empty URL.

        Raises:
            AssertionError: If more than one of the three sources is given.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first "host" header, when present, wins over `server`.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # `.get` rather than `[...]`: a scheme outside this table has
                # no default port, so its port is always kept in the URL
                # instead of raising KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit()` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True when the scheme is "https" or "wss"."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts the `SplitResult` field names (`scheme`, `netloc`, `path`,
        `query`, `fragment`) plus `username`, `password`, `hostname` and
        `port`. When any of the latter four is given, `netloc` is rebuilt
        from them, falling back to the current values for the ones omitted.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the host from the current netloc: drop any
                # "user:pass@" prefix, then any ":port" suffix.
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # A trailing "]" means a bracketed IPv6 literal with no port,
                # so there is no suffix to strip. `endswith` (rather than
                # indexing `[-1]`) keeps this safe when the netloc is empty.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """Return a new URL with `kwargs` added to, or overriding, the query."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a new URL whose query consists only of `kwargs`."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a new URL with the given query key(s) removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compared by string form, so a URL also equals its plain `str`.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never expose the password in debugging output.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
-
-
class URLPath(str):
    """
    A `str` holding a URL path, optionally tagged with a protocol and a host.
    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        # Only the path becomes the string value; the extras are set in __init__.
        assert protocol in ("http", "websocket", "")
        return super().__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol, self.host = protocol, host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Combine this path with `base_url` into an absolute URL.

        The scheme follows this path's protocol (secure variant when the base
        is secure) or else the base scheme; the host falls back to the base
        netloc; the path is appended to the base path.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            secure_scheme, plain_scheme = {
                "http": ("https", "http"),
                "websocket": ("wss", "ws"),
            }[self.protocol]
            scheme = secure_scheme if base.is_secure else plain_scheme

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )
-
-
class Secret:
    """
    Wraps a string so that its `repr()` never shows the value, keeping it
    out of tracebacks and similar output.
    Convert to `str` at the point where the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Fixed mask: the output does not depend on the wrapped value.
        return f"{self.__class__.__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped value.
        return True if self._value else False
-
-
class CommaSeparatedStrings(Sequence[str]):
    """
    A read-only sequence of strings, built either from a comma-separated
    string (split with `shlex` in POSIX mode) or from an existing sequence.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        # Treat "," as the only separator; surrounding spaces are stripped
        # from each token afterwards.
        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
-
-
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that remembers every (key, value) pair it was given.

    The plain mapping interface exposes the last value stored for each key;
    `getlist()` and `multi_items()` expose all of them.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source: Any = args[0] if args else []
        if kwargs:
            # Keyword pairs are appended after the positional ones.
            positional_pairs = ImmutableMultiDict(source).multi_items()
            keyword_pairs = ImmutableMultiDict(kwargs).multi_items()
            source = positional_pairs + keyword_pairs

        pairs: list[tuple[Any, Any]]
        if not source:
            pairs = []
        elif hasattr(source, "multi_items"):
            multi = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], source)
            pairs = list(multi.multi_items())
        elif hasattr(source, "items"):
            mapping = cast(Mapping[_KeyType, _CovariantValueType], source)
            pairs = list(mapping.items())
        else:
            pairs = list(cast("list[tuple[Any, Any]]", source))

        # Later pairs overwrite earlier ones, so `_dict` holds the last value.
        self._dict = {k: v for k, v in pairs}
        self._list = pairs

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """Return every value stored under `key`, in insertion order."""
        return [v for k, v in self._list if k == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of the full list of (key, value) pairs."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        # Equal when the other object is of this class and holds the same
        # pairs, regardless of order.
        if isinstance(other, self.__class__):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.multi_items()!r})"
-
-
class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    The mutable counterpart of `ImmutableMultiDict`.

    Every mutation keeps the pair list (`_list`) and the last-value
    dictionary (`_dict`) in step with each other.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        self._list = [pair for pair in self._list if pair[0] != key]
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove all pairs for `key`; return its last value or `default`."""
        self._list = [pair for pair in self._list if pair[0] != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """Remove one key (as `dict.popitem` picks it) with all its pairs."""
        popped_key, popped_value = self._dict.popitem()
        self._list = [pair for pair in self._list if pair[0] != popped_key]
        return popped_key, popped_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove `key` and return all the values it held."""
        removed = [v for k, v in self._list if k == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))
        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """Replace all pairs for `key`; an empty `values` removes the key."""
        if not values:
            self.pop(key, None)
            return

        kept = [(k, v) for k, v in self._list if k != key]
        self._list = kept + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: Any, value: Any) -> None:
        """Add one more pair for `key`, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """Replace the pairs of every incoming key with the incoming pairs."""
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        kept = [(k, v) for k, v in self._list if k not in incoming_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
-
-
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict.

    Also accepts a query string (`str`, or `bytes` decoded as latin-1);
    all keys and values are converted to `str`.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if isinstance(source, bytes):
            source = source.decode("latin-1")

        if isinstance(source, str):
            parsed = parse_qsl(source, keep_blank_values=True)
            super().__init__(parsed, **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        # Normalise everything to strings, whatever the input types were.
        self._list = [(str(name), str(item)) for name, item in self._list]
        self._dict = {str(name): str(item) for name, item in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
-
-
class UploadFile:
    """
    An uploaded file included as part of the request data.

    The async methods call the file object directly while it is still held
    in memory and go through `run_in_threadpool` otherwise.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # Remember SpooledTemporaryFile's max size up front so later checks
        # are cheaper. As in SpooledTemporaryFile's __init__, 0 means unlimited.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # SpooledTemporaryFile._rolled; files without it count as on disk.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        """Tell whether writing `size_to_add` bytes involves the disk."""
        # Anything not held in memory is treated as already rolled.
        if not self._in_memory:
            return True

        # Compare against SpooledTemporaryFile._max_size (0 = no limit).
        projected_size = self.file.tell() + size_to_add
        if not self._max_mem_size:
            return False
        return bool(projected_size > self._max_mem_size)

    async def write(self, data: bytes) -> None:
        added = len(data)
        if self.size is not None:
            self.size += added

        if not self._will_roll(added):
            self.file.write(data)
        else:
            await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
        else:
            self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
        else:
            self.file.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
-
-
class FormData(ImmutableMultiDict[str, Union[UploadFile, str]]):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` held in the form, one after another."""
        for _, item in self.multi_items():
            if not isinstance(item, UploadFile):
                continue
            await item.close()
-
-
class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are stored as a list of `(name, value)` byte pairs; names are
    lower-cased on lookup and text is converted using latin-1.
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), text.encode("latin-1")) for name, text in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] may be a tuple or another iterable, so it is
            # turned into a list, which is also stored back into the scope.
            as_list = list(scope["headers"])
            self._list = as_list
            scope["headers"] = as_list

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [text.decode("latin-1") for _, text in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(name.decode("latin-1"), text.decode("latin-1")) for name, text in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value stored for header `key`, in list order."""
        wanted = key.lower().encode("latin-1")
        return [text.decode("latin-1") for name, text in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # The first matching entry wins.
        wanted = key.lower().encode("latin-1")
        for name, text in self._list:
            if name == wanted:
                return text.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        # Order-insensitive comparison of the raw pairs.
        if isinstance(other, Headers):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        name = self.__class__.__name__
        as_dict = dict(self.items())
        # The dict form is only used when no header name is repeated.
        if len(as_dict) != len(self):
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"
-
-
class MutableHeaders(Headers):
    """
    A `Headers` whose entries can be changed.

    All mutations modify the underlying list in place, and `raw` returns
    that same list rather than a copy.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        positions = [pos for pos, (name, _) in enumerate(self._list) if name == set_key]
        if not positions:
            self._list.append((set_key, set_value))
            return

        first, *duplicates = positions
        # Delete from the back so the remaining positions stay valid.
        for pos in reversed(duplicates):
            del self._list[pos]
        self._list[first] = (set_key, set_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode("latin-1")

        positions = [pos for pos, (name, _) in enumerate(self._list) if name == del_key]
        # Delete from the back so the remaining positions stay valid.
        for pos in reversed(positions):
            del self._list[pos]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        if isinstance(other, Mapping):
            self.update(other)
            return self
        raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        if isinstance(other, Mapping):
            merged = self.mutablecopy()
            merged.update(other)
            return merged
        raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for name, current in self._list:
            if name == set_key:
                return current.decode("latin-1")

        self._list.append((set_key, set_value))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        for name, text in other.items():
            self[name] = text

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        entry = (key.lower().encode("latin-1"), value.encode("latin-1"))
        self._list.append(entry)

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the "vary" header, joining with any existing value."""
        current = self.get("vary")
        self["vary"] = vary if current is None else ", ".join([current, vary])
-
-
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.

    Attribute assignment, lookup and deletion are redirected to the
    `_state` dictionary.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would try to write into `_state`
        # before it exists.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        """
        Look `key` up in the state dictionary.

        Raises:
            AttributeError: If `key` is not stored, or if the instance has no
                `_state` yet.
        """
        message = "'{}' object has no attribute '{}'"
        try:
            # Read `_state` through the default lookup instead of
            # `self._state`: on an instance created without __init__ (as
            # `copy.copy` / `copy.deepcopy` do when rebuilding an object)
            # `self._state` would re-enter __getattr__ and recurse forever.
            state = super().__getattribute__("_state")
        except AttributeError:
            raise AttributeError(message.format(self.__class__.__name__, key)) from None
        try:
            return state[key]
        except KeyError:
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: Any) -> None:
        del self._state[key]
|