You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

153 lines
5.4 KiB

  1. from contextlib import contextmanager
  2. from inspect import isasyncgenfunction, iscoroutinefunction
  3. from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, cast
  4. import pytest
  5. import sniffio
  6. from ._core._eventloop import get_all_backends, get_asynclib
  7. from .abc import TestRunner
  8. if TYPE_CHECKING:
  9. from _pytest.config import Config
# The TestRunner installed by the outermost active ``get_runner()`` block, or
# None when no runner is active. Nested ``get_runner()`` calls yield this
# instance instead of creating a new runner.
_current_runner: Optional[TestRunner] = None
  11. def extract_backend_and_options(backend: object) -> Tuple[str, Dict[str, Any]]:
  12. if isinstance(backend, str):
  13. return backend, {}
  14. elif isinstance(backend, tuple) and len(backend) == 2:
  15. if isinstance(backend[0], str) and isinstance(backend[1], dict):
  16. return cast(Tuple[str, Dict[str, Any]], backend)
  17. raise TypeError('anyio_backend must be either a string or tuple of (string, dict)')
  18. @contextmanager
  19. def get_runner(backend_name: str, backend_options: Dict[str, Any]) -> Iterator[TestRunner]:
  20. global _current_runner
  21. if _current_runner:
  22. yield _current_runner
  23. return
  24. asynclib = get_asynclib(backend_name)
  25. token = None
  26. if sniffio.current_async_library_cvar.get(None) is None:
  27. # Since we're in control of the event loop, we can cache the name of the async library
  28. token = sniffio.current_async_library_cvar.set(backend_name)
  29. try:
  30. backend_options = backend_options or {}
  31. with asynclib.TestRunner(**backend_options) as runner:
  32. _current_runner = runner
  33. yield runner
  34. finally:
  35. _current_runner = None
  36. if token:
  37. sniffio.current_async_library_cvar.reset(token)
  38. def pytest_configure(config: "Config") -> None:
  39. config.addinivalue_line('markers', 'anyio: mark the (coroutine function) test to be run '
  40. 'asynchronously via anyio.')
def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
    """Run async fixtures through an anyio test runner.

    When the fixture function is a coroutine function or an async generator
    function, and the current request involves the ``anyio_backend`` fixture,
    ``fixturedef.func`` is replaced with a synchronous generator wrapper that
    drives the original fixture through ``get_runner()``. If the fixture did
    not already take ``anyio_backend``, that name is appended to
    ``fixturedef.argnames`` so the wrapper receives the backend value.
    """
    def wrapper(*args, anyio_backend, **kwargs):  # type: ignore[no-untyped-def]
        # ``func`` and ``has_backend_arg`` are closure variables bound further
        # down in the enclosing function, before ``wrapper`` is installed.
        backend_name, backend_options = extract_backend_and_options(anyio_backend)
        if has_backend_arg:
            # The original fixture declared ``anyio_backend`` itself, so pass it on.
            kwargs['anyio_backend'] = anyio_backend

        # The runner context stays open across the ``yield`` statements below.
        with get_runner(backend_name, backend_options) as runner:
            if isasyncgenfunction(func):
                gen = func(*args, **kwargs)
                # Setup: advance the async generator to its first yield.
                try:
                    value = runner.call(gen.asend, None)
                except StopAsyncIteration:
                    raise RuntimeError('Async generator did not yield')

                yield value

                # Teardown: resume once more; the generator is expected to finish.
                try:
                    runner.call(gen.asend, None)
                except StopAsyncIteration:
                    pass
                else:
                    # A second yield is an error: close the generator, then report it.
                    runner.call(gen.aclose)
                    raise RuntimeError('Async generator fixture did not stop')
            else:
                # Coroutine fixture: run it via runner.call() and yield its result.
                yield runner.call(func, *args, **kwargs)

    # Only apply this to coroutine functions and async generator functions in requests that involve
    # the anyio_backend fixture
    # (Once installed, ``wrapper`` is a plain generator function, so it fails the
    # check below and is not wrapped a second time.)
    func = fixturedef.func
    if isasyncgenfunction(func) or iscoroutinefunction(func):
        if 'anyio_backend' in request.fixturenames:
            has_backend_arg = 'anyio_backend' in fixturedef.argnames
            fixturedef.func = wrapper
            if not has_backend_arg:
                fixturedef.argnames += ('anyio_backend',)
  72. @pytest.hookimpl(tryfirst=True)
  73. def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:
  74. if collector.istestfunction(obj, name):
  75. inner_func = obj.hypothesis.inner_test if hasattr(obj, 'hypothesis') else obj
  76. if iscoroutinefunction(inner_func):
  77. marker = collector.get_closest_marker('anyio')
  78. own_markers = getattr(obj, 'pytestmark', ())
  79. if marker or any(marker.name == 'anyio' for marker in own_markers):
  80. pytest.mark.usefixtures('anyio_backend')(obj)
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: Any) -> Optional[bool]:
    """Run coroutine test functions through an anyio test runner.

    This hook only acts when the test received a truthy ``anyio_backend`` value.

    * Hypothesis tests: if the inner test is a coroutine function, it is replaced
      with a synchronous wrapper and ``None`` is returned. The actual call is left
      to pytest's other ``pytest_pyfunc_call`` implementations.
    * Plain coroutine tests: the test is run here and ``True`` is returned to
      signal that the call was handled.

    In every other case ``None`` is returned.
    """
    def run_with_hypothesis(**kwargs: Any) -> None:
        # Uses backend_name, backend_options and original_func from the
        # enclosing scope; they are assigned below before this wrapper is installed.
        with get_runner(backend_name, backend_options) as runner:
            runner.call(original_func, **kwargs)

    backend = pyfuncitem.funcargs.get('anyio_backend')
    if backend:
        backend_name, backend_options = extract_backend_and_options(backend)

        if hasattr(pyfuncitem.obj, 'hypothesis'):
            # Wrap the inner test function unless it's already wrapped
            # NOTE(review): if an earlier item already installed the wrapper, that
            # wrapper keeps the backend captured by the earlier call; confirm this
            # is intended for tests parametrized over several backends.
            original_func = pyfuncitem.obj.hypothesis.inner_test
            if original_func.__qualname__ != run_with_hypothesis.__qualname__:
                if iscoroutinefunction(original_func):
                    pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis

            return None

        if iscoroutinefunction(pyfuncitem.obj):
            funcargs = pyfuncitem.funcargs
            # Pass only the arguments the test function itself declares.
            testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
            with get_runner(backend_name, backend_options) as runner:
                runner.call(pyfuncitem.obj, **testargs)

            return True

    return None
  103. @pytest.fixture(params=get_all_backends())
  104. def anyio_backend(request: Any) -> Any:
  105. return request.param
  106. @pytest.fixture
  107. def anyio_backend_name(anyio_backend: Any) -> str:
  108. if isinstance(anyio_backend, str):
  109. return anyio_backend
  110. else:
  111. return anyio_backend[0]
  112. @pytest.fixture
  113. def anyio_backend_options(anyio_backend: Any) -> Dict[str, Any]:
  114. if isinstance(anyio_backend, str):
  115. return {}
  116. else:
  117. return anyio_backend[1]