You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.4 KiB
40 lines
1.4 KiB
from contextlib import AsyncExitStack as AsyncExitStack # noqa |
|
from contextlib import asynccontextmanager as asynccontextmanager |
|
from typing import AsyncGenerator, ContextManager, TypeVar |
|
|
|
import anyio |
|
from anyio import CapacityLimiter |
|
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa |
|
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa |
|
from starlette.concurrency import ( # noqa |
|
run_until_first_complete as run_until_first_complete, |
|
) |
|
|
|
_T = TypeVar("_T") |
|
|
|
|
|
@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: ContextManager[_T],
) -> AsyncGenerator[_T, None]:
    """Adapt a synchronous context manager for use with ``async with``.

    ``cm.__enter__`` is run through ``run_in_threadpool`` and ``cm.__exit__``
    through ``anyio.to_thread.run_sync`` so that a blocking context manager
    is not executed directly in the calling coroutine.

    Args:
        cm: The synchronous context manager to wrap.

    Yields:
        The value returned by ``cm.__enter__()``.

    Raises:
        Exception: Anything raised by ``cm.__enter__`` (``cm.__exit__`` is
            then *not* called, matching the ``with`` statement), or any
            exception raised in the ``async with`` body that ``cm.__exit__``
            does not suppress by returning a true value.

    Note:
        Only ``Exception`` subclasses are routed through ``cm.__exit__``;
        other ``BaseException`` types propagate without ``cm.__exit__``
        being called. NOTE(review): confirm this is intended for
        cancellation.
    """
    # Making __exit__ wait for a free thread in the shared pool can create
    # race conditions/deadlocks if the context manager itself has its own
    # internal pool (e.g. a database connection pool). To avoid this we let
    # __exit__ run without the shared capacity limit. Since we're creating a
    # new limiter for each call, any non-zero limit works (1 is arbitrary).
    exit_limiter = CapacityLimiter(1)

    # Enter *before* the ``try`` block: if __enter__ fails, the manager was
    # never entered, so __exit__ must not be invoked for it.
    entered_value = await run_in_threadpool(cm.__enter__)
    try:
        yield entered_value
    except Exception as e:
        # Hand the real traceback to __exit__ so the manager can inspect it
        # (and so managers that write it back onto the exception do not
        # erase it).
        suppressed = bool(
            await anyio.to_thread.run_sync(
                cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
            )
        )
        if not suppressed:
            # Re-raise the exception currently being handled.
            raise
    else:
        await anyio.to_thread.run_sync(
            cm.__exit__, None, None, None, limiter=exit_limiter
        )
|
|
|