Skip to content

Commit c518300

Browse files
authored
Fixed AssertionError when using nest_asyncio (#841)
This stems from the incorrect placement of `nest_asyncio.apply()`, as it should be called before `asyncio.run()`. Fixes #840.
1 parent d14f005 commit c518300

File tree

3 files changed

+31
-16
lines changed

3 files changed

+31
-16
lines changed

docs/versionhistory.rst

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
99
thread waker socket pair. This should improve the performance of ``wait_readable()``
1010
and ``wait_writable()`` when using the ``ProactorEventLoop``
1111
(`#836 <https://github.com/agronholm/anyio/pull/836>`_; PR by @graingert)
12+
- Fixed ``AssertionError`` when using ``nest-asyncio``
13+
(`#840 <https://github.com/agronholm/anyio/issues/840>`_)
1214

1315
**4.7.0**
1416

src/anyio/_backends/_asyncio.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -677,40 +677,42 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
677677
self.cancel_scope = cancel_scope
678678

679679

680-
class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
680+
class TaskStateStore(
681+
MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState]
682+
):
681683
def __init__(self) -> None:
682684
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
683-
self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}
685+
self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {}
684686

685-
def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
686-
assert isinstance(key, asyncio.Task)
687+
def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState:
688+
task = cast(asyncio.Task, key)
687689
try:
688-
return self._task_states[key]
690+
return self._task_states[task]
689691
except KeyError:
690-
if coro := key.get_coro():
692+
if coro := task.get_coro():
691693
if state := self._preliminary_task_states.get(coro):
692694
return state
693695

694696
raise KeyError(key)
695697

696698
def __setitem__(
697-
self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
699+
self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, /
698700
) -> None:
699-
if isinstance(key, asyncio.Task):
700-
self._task_states[key] = value
701-
else:
701+
if isinstance(key, Coroutine):
702702
self._preliminary_task_states[key] = value
703-
704-
def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
705-
if isinstance(key, asyncio.Task):
706-
del self._task_states[key]
707703
else:
704+
self._task_states[key] = value
705+
706+
def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None:
707+
if isinstance(key, Coroutine):
708708
del self._preliminary_task_states[key]
709+
else:
710+
del self._task_states[key]
709711

710712
def __len__(self) -> int:
711713
return len(self._task_states) + len(self._preliminary_task_states)
712714

713-
def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
715+
def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]:
714716
yield from self._task_states
715717
yield from self._preliminary_task_states
716718

tests/test_taskgroups.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pytest
1313
from exceptiongroup import catch
14-
from pytest import FixtureRequest
14+
from pytest import FixtureRequest, MonkeyPatch
1515
from pytest_mock import MockerFixture
1616

1717
import anyio
@@ -1778,3 +1778,14 @@ async def sync_coro() -> None:
17781778
async with create_task_group() as tg:
17791779
tg.start_soon(sync_coro)
17801780
tg.cancel_scope.cancel()
1781+
1782+
1783+
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
1784+
async def test_patched_asyncio_task(monkeypatch: MonkeyPatch) -> None:
1785+
monkeypatch.setattr(
1786+
asyncio,
1787+
"Task",
1788+
asyncio.tasks._PyTask, # type: ignore[attr-defined]
1789+
)
1790+
async with create_task_group() as tg:
1791+
tg.start_soon(sleep, 0)

0 commit comments

Comments
 (0)