|
- from contextlib import asynccontextmanager as asynccontextmanager
- from typing import AsyncGenerator, ContextManager, TypeVar
-
- import anyio.to_thread
- from anyio import CapacityLimiter
- from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
- from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
- from starlette.concurrency import ( # noqa
- run_until_first_complete as run_until_first_complete,
- )
-
- _T = TypeVar("_T")
-
-
- @asynccontextmanager
- async def contextmanager_in_threadpool(
- cm: ContextManager[_T],
- ) -> AsyncGenerator[_T, None]:
- # blocking __exit__ from running waiting on a free thread
- # can create race conditions/deadlocks if the context manager itself
- # has its own internal pool (e.g. a database connection pool)
- # to avoid this we let __exit__ run without a capacity limit
- # since we're creating a new limiter for each call, any non-zero limit
- # works (1 is arbitrary)
- exit_limiter = CapacityLimiter(1)
- try:
- yield await run_in_threadpool(cm.__enter__)
- except Exception as e:
- ok = bool(
- await anyio.to_thread.run_sync(
- cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
- )
- )
- if not ok:
- raise e
- else:
- await anyio.to_thread.run_sync(
- cm.__exit__, None, None, None, limiter=exit_limiter
- )
|