You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

102 lines
2.7 KiB

  1. from __future__ import annotations
  2. import functools
  3. import inspect
  4. import sys
  5. from collections.abc import Awaitable, Generator
  6. from contextlib import AbstractAsyncContextManager, contextmanager
  7. from typing import Any, Callable, Generic, Protocol, TypeVar, overload
  8. from starlette.types import Scope
  9. if sys.version_info >= (3, 13): # pragma: no cover
  10. from typing import TypeIs
  11. else: # pragma: no cover
  12. from typing_extensions import TypeIs
  13. has_exceptiongroups = True
  14. if sys.version_info < (3, 11): # pragma: no cover
  15. try:
  16. from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
  17. except ImportError:
  18. has_exceptiongroups = False
  19. T = TypeVar("T")
  20. AwaitableCallable = Callable[..., Awaitable[T]]
  21. @overload
  22. def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
  23. @overload
  24. def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
  25. def is_async_callable(obj: Any) -> Any:
  26. while isinstance(obj, functools.partial):
  27. obj = obj.func
  28. return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))
  29. T_co = TypeVar("T_co", covariant=True)
  30. class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ...
  31. class SupportsAsyncClose(Protocol):
  32. async def close(self) -> None: ... # pragma: no cover
  33. SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
  34. class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
  35. __slots__ = ("aw", "entered")
  36. def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
  37. self.aw = aw
  38. def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
  39. return self.aw.__await__()
  40. async def __aenter__(self) -> SupportsAsyncCloseType:
  41. self.entered = await self.aw
  42. return self.entered
  43. async def __aexit__(self, *args: Any) -> None | bool:
  44. await self.entered.close()
  45. return None
  46. @contextmanager
  47. def collapse_excgroups() -> Generator[None, None, None]:
  48. try:
  49. yield
  50. except BaseException as exc:
  51. if has_exceptiongroups: # pragma: no cover
  52. while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
  53. exc = exc.exceptions[0]
  54. raise exc
  55. def get_route_path(scope: Scope) -> str:
  56. path: str = scope["path"]
  57. root_path = scope.get("root_path", "")
  58. if not root_path:
  59. return path
  60. if not path.startswith(root_path):
  61. return path
  62. if path == root_path:
  63. return ""
  64. if path[len(root_path)] == "/":
  65. return path[len(root_path) :]
  66. return path