|
- from __future__ import annotations
-
- import functools
- import inspect
- import sys
- from collections.abc import Awaitable, Generator
- from contextlib import AbstractAsyncContextManager, contextmanager
- from typing import Any, Callable, Generic, Protocol, TypeVar, overload
-
- from starlette.types import Scope
-
- if sys.version_info >= (3, 13): # pragma: no cover
- from typing import TypeIs
- else: # pragma: no cover
- from typing_extensions import TypeIs
-
if sys.version_info >= (3, 11):  # pragma: no cover
    # BaseExceptionGroup is a builtin from 3.11 onwards.
    has_exceptiongroups = True
else:  # pragma: no cover
    # Older interpreters need the optional ``exceptiongroup`` backport; record
    # whether it could be imported so callers know if the name is usable.
    try:
        from exceptiongroup import BaseExceptionGroup  # type: ignore[unused-ignore,import-not-found]
    except ImportError:
        has_exceptiongroups = False
    else:
        has_exceptiongroups = True
-
T = TypeVar("T")
# A callable of any signature whose call result is an awaitable resolving to T.
AwaitableCallable = Callable[..., Awaitable[T]]


@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...


@overload
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...


def is_async_callable(obj: Any) -> Any:
    """Return whether ``obj`` should be treated as an async callable.

    Any nesting of ``functools.partial`` is unwrapped first. The result is True
    when the underlying object is a coroutine function, or when it is callable
    and its ``__call__`` attribute is a coroutine function; otherwise False.
    """
    target = obj
    while isinstance(target, functools.partial):
        target = target.func

    if inspect.iscoroutinefunction(target):
        return True
    if not callable(target):
        return False
    # Covers instances of classes that define ``async def __call__``.
    # NOTE(review): a *class* defining ``async def __call__`` also yields True
    # here, since ``cls.__call__`` resolves to that coroutine function.
    return inspect.iscoroutinefunction(target.__call__)
-
-
T_co = TypeVar("T_co", covariant=True)


class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]):
    """Structural type for objects usable either via ``await`` or ``async with``.

    Both forms are parameterized by the same covariant ``T_co``.
    """
-
-
class SupportsAsyncClose(Protocol):
    """Structural type for objects exposing an awaitable ``close()`` method."""

    async def close(self) -> None: ...  # pragma: no cover


# Invariant (TypeVar's default) type variable bound to SupportsAsyncClose.
SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose)


class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
    """Adapt an awaitable so it can be awaited or used with ``async with``.

    ``await wrapper`` simply delegates to the wrapped awaitable. ``async with
    wrapper as value`` awaits it on entry and calls ``await value.close()`` on
    exit; exit never suppresses an exception raised inside the block.
    """

    __slots__ = ("aw", "entered")

    def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
        # ``entered`` is only assigned later, by __aenter__.
        self.aw = aw

    def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
        # Hand the wrapped awaitable's iterator straight back to the caller.
        return self.aw.__await__()

    async def __aenter__(self) -> SupportsAsyncCloseType:
        resource = await self.aw
        self.entered = resource
        return resource

    async def __aexit__(self, *args: Any) -> None | bool:
        # Falling off the end returns None, so exceptions propagate.
        await self.entered.close()
-
-
@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
    """Re-raise exceptions from the ``with`` body, unwrapping trivial groups.

    Every exception (any ``BaseException``) raised in the body is re-raised.
    When exception groups are available, a group holding exactly one member is
    replaced by that member, repeatedly, so nested single-member groups collapse
    down to the innermost exception.
    """
    try:
        yield
    except BaseException as error:
        if has_exceptiongroups:  # pragma: no cover
            while isinstance(error, BaseExceptionGroup) and len(error.exceptions) == 1:
                (error,) = error.exceptions
        # NOTE(review): raising a group member from inside this handler will
        # presumably record the group as the member's implicit __context__.
        raise error
-
-
def get_route_path(scope: Scope) -> str:
    """Return ``scope["path"]`` with the ``root_path`` prefix removed.

    The prefix (``scope.get("root_path", "")``) is stripped only when it is
    non-empty and ``path`` starts with it at a segment boundary. An exact match
    yields ``""``. A match followed by ``"/"`` yields the remainder, which keeps
    its leading slash. Otherwise ``path`` is returned untouched, so a root path
    of ``"/api"`` leaves ``"/apiv2"`` alone.
    """
    full_path: str = scope["path"]
    prefix = scope.get("root_path", "")
    if not prefix or not full_path.startswith(prefix):
        return full_path

    remainder = full_path[len(prefix) :]
    if not remainder:
        return ""
    # Only treat the prefix as a mount point when it ends on a "/" boundary.
    return remainder if remainder[0] == "/" else full_path
|