@@ -677,40 +677,42 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
677
677
self .cancel_scope = cancel_scope
678
678
679
679
680
- class TaskStateStore (MutableMapping ["Awaitable[Any] | asyncio.Task" , TaskState ]):
680
+ class TaskStateStore (
681
+ MutableMapping ["Coroutine[Any, Any, Any] | asyncio.Task" , TaskState ]
682
+ ):
681
683
def __init__ (self ) -> None :
682
684
self ._task_states = WeakKeyDictionary [asyncio .Task , TaskState ]()
683
- self ._preliminary_task_states : dict [Awaitable [ Any ], TaskState ] = {}
685
+ self ._preliminary_task_states : dict [Coroutine [ Any , Any , Any ], TaskState ] = {}
684
686
685
- def __getitem__ (self , key : Awaitable [ Any ] | asyncio .Task , / ) -> TaskState :
686
- assert isinstance ( key , asyncio .Task )
687
+ def __getitem__ (self , key : Coroutine [ Any , Any , Any ] | asyncio .Task , / ) -> TaskState :
688
+ task = cast ( asyncio .Task , key )
687
689
try :
688
- return self ._task_states [key ]
690
+ return self ._task_states [task ]
689
691
except KeyError :
690
- if coro := key .get_coro ():
692
+ if coro := task .get_coro ():
691
693
if state := self ._preliminary_task_states .get (coro ):
692
694
return state
693
695
694
696
raise KeyError (key )
695
697
696
698
def __setitem__ (
697
- self , key : asyncio .Task | Awaitable [ Any ], value : TaskState , /
699
+ self , key : asyncio .Task | Coroutine [ Any , Any , Any ], value : TaskState , /
698
700
) -> None :
699
- if isinstance (key , asyncio .Task ):
700
- self ._task_states [key ] = value
701
- else :
701
+ if isinstance (key , Coroutine ):
702
702
self ._preliminary_task_states [key ] = value
703
-
704
- def __delitem__ (self , key : asyncio .Task | Awaitable [Any ], / ) -> None :
705
- if isinstance (key , asyncio .Task ):
706
- del self ._task_states [key ]
707
703
else :
704
+ self ._task_states [key ] = value
705
+
706
+ def __delitem__ (self , key : asyncio .Task | Coroutine [Any , Any , Any ], / ) -> None :
707
+ if isinstance (key , Coroutine ):
708
708
del self ._preliminary_task_states [key ]
709
+ else :
710
+ del self ._task_states [key ]
709
711
710
712
def __len__ (self ) -> int :
711
713
return len (self ._task_states ) + len (self ._preliminary_task_states )
712
714
713
- def __iter__ (self ) -> Iterator [Awaitable [ Any ] | asyncio .Task ]:
715
+ def __iter__ (self ) -> Iterator [Coroutine [ Any , Any , Any ] | asyncio .Task ]:
714
716
yield from self ._task_states
715
717
yield from self ._preliminary_task_states
716
718
0 commit comments