@@ -679,15 +679,16 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
679
679
680
680
class TaskStateStore (MutableMapping ["Awaitable[Any] | asyncio.Task" , TaskState ]):
681
681
def __init__ (self ) -> None :
682
- self ._task_states = WeakKeyDictionary [asyncio .Task , TaskState ]()
682
+ self ._task_states = WeakKeyDictionary [
683
+ "asyncio.Task | Awaitable[Any]" , TaskState
684
+ ]()
683
685
self ._preliminary_task_states : dict [Awaitable [Any ], TaskState ] = {}
684
686
685
687
def __getitem__ (self , key : Awaitable [Any ] | asyncio .Task , / ) -> TaskState :
686
- assert isinstance (key , asyncio .Task )
687
688
try :
688
689
return self ._task_states [key ]
689
690
except KeyError :
690
- if coro := key .get_coro ():
691
+ if coro := cast ( asyncio . Task , key ) .get_coro ():
691
692
if state := self ._preliminary_task_states .get (coro ):
692
693
return state
693
694
@@ -696,16 +697,16 @@ def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
696
697
def __setitem__ (
697
698
self , key : asyncio .Task | Awaitable [Any ], value : TaskState , /
698
699
) -> None :
699
- if isinstance (key , asyncio .Task ):
700
- self ._task_states [key ] = value
701
- else :
700
+ if isinstance (key , Coroutine ):
702
701
self ._preliminary_task_states [key ] = value
702
+ else :
703
+ self ._task_states [key ] = value
703
704
704
705
def __delitem__ (self , key : asyncio .Task | Awaitable [Any ], / ) -> None :
705
- if isinstance (key , asyncio .Task ):
706
- del self ._task_states [key ]
707
- else :
706
+ if isinstance (key , Coroutine ):
708
707
del self ._preliminary_task_states [key ]
708
+ else :
709
+ del self ._task_states [key ]
709
710
710
711
def __len__ (self ) -> int :
711
712
return len (self ._task_states ) + len (self ._preliminary_task_states )
0 commit comments