You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

84 lines
2.7 KiB

  1. from contextvars import ContextVar
  2. from typing import Optional
  3. import sys
  # Explicit override for the current async library: when this holds a
# non-None string, current_async_library() returns it without any sniffing.
# It is only ever read (never written) by this module.
current_async_library_cvar = ContextVar(
    "current_async_library_cvar", default=None
)  # type: ContextVar[Optional[str]]


class AsyncLibraryNotFoundError(RuntimeError):
    """Raised when the currently running async library cannot be determined."""


def current_async_library() -> str:
    """Detect which async library is currently running.

    The following libraries are currently supported:

    ================  ===========  ============================
    Library           Requires     Magic string
    ================  ===========  ============================
    **Trio**          Trio v0.6+   ``"trio"``
    **Curio**         -            ``"curio"``
    **asyncio**                    ``"asyncio"``
    **Trio-asyncio**  v0.8.2+      ``"trio"`` or ``"asyncio"``,
                                   depending on current mode
    ================  ===========  ============================

    Detection order: a non-None value in ``current_async_library_cvar``
    wins; otherwise Curio is sniffed, then asyncio. Sniffed results are
    never cached, so every call reflects the context it is made from.

    Returns:
      A string like ``"trio"``.

    Raises:
      AsyncLibraryNotFoundError: if called from synchronous context,
        or if the current async library was not recognized.

    Examples:

        .. code-block:: python3

           from sniffio import current_async_library

           async def generic_sleep(seconds):
               library = current_async_library()
               if library == "trio":
                   import trio
                   await trio.sleep(seconds)
               elif library == "asyncio":
                   import asyncio
                   await asyncio.sleep(seconds)
               # ... and so on ...
               else:
                   raise RuntimeError(f"Unsupported library {library!r}")

    """
    value = current_async_library_cvar.get()
    if value is not None:
        return value

    # Sniff for curio (for now). Only checked if curio is already imported,
    # so merely calling this function never imports curio.
    if "curio" in sys.modules:
        from curio.meta import curio_running

        if curio_running():
            return "curio"

    # Need to sniff for asyncio; again only if it is already imported.
    if "asyncio" in sys.modules:
        import asyncio

        try:
            current_task = asyncio.current_task  # type: ignore[attr-defined]
        except AttributeError:
            # asyncio.current_task() only exists on Python 3.7+; older
            # versions expose the Task.current_task classmethod instead.
            current_task = asyncio.Task.current_task  # type: ignore[attr-defined]
        try:
            if current_task() is not None:
                # Deliberately NOT cached in current_async_library_cvar:
                # contexts copied from this task (e.g. a sync function run
                # via asyncio.to_thread, or a call_soon callback) would
                # inherit the cached value and wrongly report "asyncio"
                # even though they are not running inside an asyncio task.
                return "asyncio"
        except RuntimeError:
            # current_task() raises RuntimeError when no event loop is
            # running in this thread, i.e. we are in synchronous context.
            pass

    raise AsyncLibraryNotFoundError(
        "unknown async library, or not in async context"
    )