|
- import tempfile
- import typing
- from collections import namedtuple
- from collections.abc import Sequence
- from shlex import shlex
- from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
-
- from starlette.concurrency import run_in_threadpool
- from starlette.types import Scope
-
# Simple two-field record: ``host`` followed by ``port``.
Address = namedtuple("Address", "host port")
-
-
class URL:
    """
    A URL string together with lazily parsed ``urlsplit`` components.

    An instance can be built from exactly one of three sources: a plain URL
    string, a ``scope`` mapping, or individual ``SplitResult`` components
    given as keyword arguments. The modifier methods (``replace`` and the
    ``*_query_params`` helpers) never mutate the instance; they return a new
    one.
    """

    def __init__(
        self, url: str = "", scope: Scope = None, **components: typing.Any
    ) -> None:
        """
        Build the URL.

        :param url: The URL as a string.
        :param scope: Mapping from which the URL is assembled. The keys read
            are ``scheme`` (default ``"http"``), ``server``, ``root_path``,
            ``path``, ``query_string`` and ``headers``.
        :param components: Keyword arguments accepted by ``replace``, applied
            to an empty URL.
        :raises AssertionError: If more than one of the three sources is
            supplied.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope.get("root_path", "") + scope["path"]
            query_string = scope.get("query_string", b"")

            # A "host" header, when present, takes precedence over "server".
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                # No host information at all: fall back to a relative URL.
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                # The port is only spelled out when it is not the scheme default.
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The ``urlsplit`` result for this URL, parsed once and then cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        """The ``scheme`` component."""
        return self.components.scheme

    @property
    def netloc(self) -> str:
        """The ``netloc`` component."""
        return self.components.netloc

    @property
    def path(self) -> str:
        """The ``path`` component."""
        return self.components.path

    @property
    def query(self) -> str:
        """The ``query`` component."""
        return self.components.query

    @property
    def fragment(self) -> str:
        """The ``fragment`` component."""
        return self.components.fragment

    @property
    def username(self) -> typing.Union[None, str]:
        """The ``username`` reported by the parsed components."""
        return self.components.username

    @property
    def password(self) -> typing.Union[None, str]:
        """The ``password`` reported by the parsed components."""
        return self.components.password

    @property
    def hostname(self) -> typing.Union[None, str]:
        """The ``hostname`` reported by the parsed components."""
        return self.components.hostname

    @property
    def port(self) -> typing.Optional[int]:
        """The ``port`` reported by the parsed components."""
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """``True`` when the scheme is ``https`` or ``wss``."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> "URL":
        """
        Return a new URL with the given components replaced.

        Accepts any ``SplitResult`` field (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``) plus ``username``, ``password``, ``hostname``
        and ``port``. If any of the latter four is given, the netloc is
        rebuilt from them and overrides a ``netloc`` keyword.

        When ``hostname`` is not supplied (or is ``None``) the host text is
        taken verbatim from the current netloc, so bracketed IPv6 literals
        keep their brackets and the host keeps its original case.
        """
        if (
            "username" in kwargs
            or "password" in kwargs
            or "hostname" in kwargs
            or "port" in kwargs
        ):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # ``self.hostname`` strips the brackets from IPv6 literals and
                # is None when there is no host, so recover the host portion
                # from the raw netloc instead: drop "user:pass@" and ":port".
                hostname = self.netloc.rpartition("@")[2]
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> "URL":
        """
        Return a new URL whose query has the given parameters set.

        Keys and values are converted with ``str``. Existing parameters with
        the same key are replaced; other parameters are kept.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> "URL":
        """Return a new URL whose query consists only of the given parameters."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(
        self, keys: typing.Union[str, typing.Sequence[str]]
    ) -> "URL":
        """
        Return a new URL with the named query parameters removed.

        :param keys: A single key or a sequence of keys. Keys that are not
            present in the query are ignored.
        """
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Comparison is purely textual, so a URL also equals a matching str.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never show the real password in a repr.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
-
-
class URLPath(str):
    """
    A path string that can additionally carry a protocol and a host.

    The string value is the path itself; ``protocol`` and ``host`` are stored
    as plain attributes and are consulted by ``make_absolute_url``.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
        # Only the path becomes the string value; the extras are set in __init__.
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)  # type: ignore

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
        """
        Combine this path with ``base_url`` and return the absolute URL text.

        The scheme follows this path's protocol when one is set (picking the
        secure variant when the base URL is secure) and the base URL's scheme
        otherwise. The host of this path, if any, overrides the base netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # Each entry is (insecure scheme, secure scheme).
            scheme_pairs = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }
            scheme = scheme_pairs[self.protocol][int(base.is_secure)]

        netloc = self.host if self.host else base.netloc
        full_path = base.path.rstrip("/") + str(self)
        return str(URL(scheme=scheme, netloc=netloc, path=full_path))
-
-
class Secret:
    """
    Wrapper around a string whose ``repr`` never shows the wrapped value.

    Convert the instance with ``str`` at the point where the real value is
    needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Fixed mask: the output does not depend on the wrapped value.
        return f"{type(self).__name__}('**********')"

    def __str__(self) -> str:
        return self._value
-
-
class CommaSeparatedStrings(Sequence):
    """
    A read-only sequence of strings.

    Built either from a single comma-delimited string (split with ``shlex``
    in POSIX mode, each token stripped of surrounding whitespace) or from
    an existing sequence, which is copied.
    """

    def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        # Treat the comma as the only separator between tokens.
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        contents = list(self)
        return f"{self.__class__.__name__}({contents!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
-
-
class ImmutableMultiDict(typing.Mapping):
    """
    A read-only mapping that can hold several values per key.

    Every ``(key, value)`` pair is kept, in order, in ``_list``. The plain
    mapping interface (``[]``, ``get``, ``keys``, ``values``, ``items``)
    exposes only the last value stored for each key; ``getlist`` and
    ``multi_items`` expose all of them.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        """
        Build from at most one positional source plus optional keyword items.

        The positional source may be another multidict (anything exposing
        ``multi_items``), a mapping (anything exposing ``items``), or an
        iterable of ``(key, value)`` pairs. Keyword items are appended after
        the positional ones.

        :raises AssertionError: If more than one positional argument is given.
        """
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []
        if kwargs:
            value = (
                ImmutableMultiDict(value).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        if not value:
            _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
        elif hasattr(value, "multi_items"):
            value = typing.cast(ImmutableMultiDict, value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = typing.cast(typing.Mapping, value)
            _items = list(value.items())
        else:
            value = typing.cast(
                typing.List[typing.Tuple[typing.Any, typing.Any]], value
            )
            _items = list(value)

        # Later pairs overwrite earlier ones, so _dict holds the last value.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
        """Return every value stored under ``key``, in insertion order."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> typing.KeysView:
        return self._dict.keys()

    def values(self) -> typing.ValuesView:
        return self._dict.values()

    def items(self) -> typing.ItemsView:
        return self._dict.items()

    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
        """Return a copy of all ``(key, value)`` pairs, duplicates included."""
        return list(self._list)

    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        return self._dict.get(key, default)

    def __getitem__(self, key: typing.Any) -> str:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        """
        Compare the full pair lists, ignoring order.

        Only instances of this object's own class (or a subclass of it)
        can compare equal.
        """
        if not isinstance(other, self.__class__):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Sorting fails when two values under the same key cannot be
            # ordered (e.g. an int and a str). Fall back to matching the
            # pairs one by one using equality only.
            if len(self._list) != len(other._list):
                return False
            unmatched = list(other._list)
            for pair in self._list:
                try:
                    unmatched.remove(pair)
                except ValueError:
                    return False
            return not unmatched

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"
-
-
class MultiDict(ImmutableMultiDict):
    """
    Mutable variant of ``ImmutableMultiDict``.

    Every mutator keeps ``_list`` (all pairs) and ``_dict`` (last value per
    key) in step with each other.
    """

    def _pairs_without(self, key: typing.Any) -> typing.List[typing.Tuple]:
        # All stored pairs except those belonging to ``key``, order preserved.
        return [(k, v) for k, v in self._list if k != key]

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        # Assigning replaces every existing value for the key with one value.
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = self._pairs_without(key)
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove all values for ``key``; return the last one or ``default``."""
        self._list = self._pairs_without(key)
        return self._dict.pop(key, default)

    def popitem(self) -> typing.Tuple:
        """Remove a key via ``dict.popitem`` and drop all of its pairs."""
        popped_key, popped_value = self._dict.popitem()
        self._list = self._pairs_without(popped_key)
        return popped_key, popped_value

    def poplist(self, key: typing.Any) -> typing.List:
        """Remove ``key`` and return every value it held."""
        removed = [v for k, v in self._list if k == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Store ``default`` under ``key`` if absent; return the current value."""
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: typing.List) -> None:
        """Replace all values of ``key``; an empty list removes the key."""
        if not values:
            self.pop(key, None)
            return
        kept = self._pairs_without(key)
        self._list = kept + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more value for ``key`` without touching existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: typing.Union[
            "MultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        """
        Merge another collection in.

        Keys present in the incoming data lose all of their current values
        and receive the incoming ones; other keys are left alone.
        """
        incoming = MultiDict(*args, **kwargs)
        untouched = [(k, v) for k, v in self._list if k not in incoming.keys()]
        self._list = untouched + incoming.multi_items()
        self._dict.update(incoming)
-
-
class QueryParams(ImmutableMultiDict):
    """
    An immutable multidict.

    In addition to the sources accepted by the parent class, the single
    positional argument may be a query string given as ``str`` or ``bytes``
    (bytes are decoded as latin-1). All keys and values end up as ``str``.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
            str,
            bytes,
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []

        if isinstance(source, (str, bytes)):
            text = source if isinstance(source, str) else source.decode("latin-1")
            # Blank values are kept so "a=&b=1" still yields a key "a".
            super().__init__(parse_qsl(text, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore

        # Normalise everything to str, whatever the source held.
        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        encoded = str(self)
        return f"{self.__class__.__name__}({encoded!r})"
-
-
class UploadFile:
    """
    An uploaded file included as part of the request data.

    Wraps a file object and exposes async ``write``/``read``/``seek``/``close``
    methods. When no file object is supplied, a ``SpooledTemporaryFile``
    capped at ``spool_max_size`` is created.
    """

    spool_max_size = 1024 * 1024

    def __init__(
        self, filename: str, file: typing.IO = None, content_type: str = ""
    ) -> None:
        self.filename = filename
        self.content_type = content_type
        self.file = (
            tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
            if file is None
            else file
        )

    @property
    def _in_memory(self) -> bool:
        # A file object without a ``_rolled`` attribute is treated as if it
        # were already rolled, i.e. as not in memory.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: typing.Union[bytes, str]) -> None:
        """Write ``data``; uses the threadpool unless the file is in memory."""
        if not self._in_memory:
            await run_in_threadpool(self.file.write, data)
            return
        self.file.write(data)  # type: ignore

    async def read(self, size: int = -1) -> typing.Union[bytes, str]:
        """Read up to ``size`` from the underlying file and return the result."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        """Move the underlying file position to ``offset``."""
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        """Close the underlying file."""
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()
-
-
class FormData(ImmutableMultiDict):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: typing.Union[
            "FormData",
            typing.Mapping[str, typing.Union[str, UploadFile]],
            typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
        ],
        **kwargs: typing.Union[str, UploadFile],
    ) -> None:
        # Construction is delegated unchanged; only the annotations differ.
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every ``UploadFile`` held in this form, one after another."""
        for _field, entry in self.multi_items():
            if not isinstance(entry, UploadFile):
                continue
            await entry.close()
-
-
class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are stored as a list of ``(bytes, bytes)`` pairs. Lookups
    lower-case the requested key and compare it to the stored keys; text is
    converted to and from bytes with latin-1.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] = None,
        raw: typing.List[typing.Tuple[bytes, bytes]] = None,
        scope: Scope = None,
    ) -> None:
        """
        Build from exactly one of ``headers``, ``raw`` or ``scope``.

        :param headers: Mapping of text names to text values; names are
            lower-cased and both parts are encoded.
        :param raw: List of byte pairs, stored as given (not copied).
        :param scope: Mapping whose ``"headers"`` entry is stored as given.
        :raises AssertionError: If more than one source is supplied.
        """
        self._list: typing.List[typing.Tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), text.encode("latin-1"))
                for name, text in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            self._list = scope["headers"]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """A copy of the stored byte pairs."""
        return list(self._list)

    def keys(self) -> typing.List[str]:  # type: ignore
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> typing.List[str]:  # type: ignore
        return [text.decode("latin-1") for _, text in self._list]

    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
        return [
            (name.decode("latin-1"), text.decode("latin-1"))
            for name, text in self._list
        ]

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        try:
            return self[key]
        except KeyError:
            return default

    def getlist(self, key: str) -> typing.List[str]:
        """Return every value stored for ``key``, in stored order."""
        wanted = key.lower().encode("latin-1")
        return [
            text.decode("latin-1") for name, text in self._list if name == wanted
        ]

    def mutablecopy(self) -> "MutableHeaders":
        """Return a ``MutableHeaders`` built from a copy of the stored list."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # The first matching entry wins when a header is repeated.
        wanted = key.lower().encode("latin-1")
        for name, text in self._list:
            if name == wanted:
                return text.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts entries, so repeated headers are counted once each.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        if isinstance(other, Headers):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict can only represent the headers when no name is repeated.
        if len(as_dict) != len(self):
            return f"{class_name}(raw={self.raw!r})"
        return f"{class_name}({as_dict!r})"
-
-
class MutableHeaders(Headers):
    """
    ``Headers`` with in-place modification.

    All mutators change the stored list object itself, and ``raw`` returns
    that same list rather than a copy.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        positions = [
            index for index, (item_key, _) in enumerate(self._list) if item_key == name
        ]
        if not positions:
            self._list.append((name, encoded))
            return

        first, *duplicates = positions
        # Delete from the back so the remaining indexes stay valid.
        for index in reversed(duplicates):
            del self._list[index]
        self._list[first] = (name, encoded)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        name = key.lower().encode("latin-1")
        # Walk backwards so deletions do not shift entries still to be visited.
        for index in range(len(self._list) - 1, -1, -1):
            if self._list[index][0] == name:
                del self._list[index]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        for item_key, item_value in self._list:
            if item_key == name:
                return item_value.decode("latin-1")

        self._list.append((name, encoded))
        return value

    def update(self, other: dict) -> None:
        """Set every header found in ``other``, replacing existing entries."""
        for name, text in other.items():
            self[name] = text

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the ``vary`` header, joining with ", " if it exists."""
        current = self.get("vary")
        if current is not None:
            vary = ", ".join([current, vary])
        self["vary"] = vary
-
-
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.

    Attribute reads, writes and deletes are all redirected to an internal
    ``_state`` dictionary. A dictionary passed to the constructor is used
    directly, not copied.
    """

    def __init__(self, state: typing.Dict = None):
        backing = {} if state is None else state
        # Go through object.__setattr__: our own __setattr__ writes into
        # _state, which does not exist yet at this point.
        super().__setattr__("_state", backing)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        try:
            return self._state[key]
        except KeyError:
            owner = self.__class__.__name__
            raise AttributeError(f"'{owner}' object has no attribute '{key}'")

    def __delattr__(self, key: typing.Any) -> None:
        del self._state[key]
|