Skip to content

Commit 0f80611

Browse files
authored
Added support for wait_readable() and wait_writable() on ProactorEventLoop (#831)
1 parent 97d5fe6 commit 0f80611

File tree

5 files changed

+210
-65
lines changed

5 files changed

+210
-65
lines changed

docs/versionhistory.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
1212
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
1313
an object with a ``.fileno()`` method or an integer handle, and deprecated
1414
their now obsolete versions (``wait_socket_readable()`` and
15-
``wait_socket_writable()`` (PR by @davidbrochart)
15+
``wait_socket_writable()``) (PR by @davidbrochart)
16+
- Added support for ``wait_readable()`` and ``wait_writable()`` on ``ProactorEventLoop``
17+
(used on asyncio + Windows by default)
1618
- Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the
1719
``anyio.AsyncFile`` class
1820
(`#825 <https://github.com/agronholm/anyio/issues/825>`_)

src/anyio/_backends/_asyncio.py

+37-25
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@
103103
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
104104

105105
if TYPE_CHECKING:
106-
from _typeshed import HasFileno
106+
from _typeshed import FileDescriptorLike
107+
else:
108+
FileDescriptorLike = object
107109

108110
if sys.version_info >= (3, 10):
109111
from typing import ParamSpec
@@ -2734,7 +2736,7 @@ async def getnameinfo(
27342736
return await get_running_loop().getnameinfo(sockaddr, flags)
27352737

27362738
@classmethod
2737-
async def wait_readable(cls, obj: HasFileno | int) -> None:
2739+
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
27382740
await cls.checkpoint()
27392741
try:
27402742
read_events = _read_events.get()
@@ -2746,25 +2748,30 @@ async def wait_readable(cls, obj: HasFileno | int) -> None:
27462748
obj = obj.fileno()
27472749

27482750
if read_events.get(obj):
2749-
raise BusyResourceError("reading from") from None
2751+
raise BusyResourceError("reading from")
27502752

27512753
loop = get_running_loop()
2752-
event = read_events[obj] = asyncio.Event()
2753-
loop.add_reader(obj, event.set)
2754+
event = asyncio.Event()
2755+
try:
2756+
loop.add_reader(obj, event.set)
2757+
except NotImplementedError:
2758+
from anyio._core._asyncio_selector_thread import get_selector
2759+
2760+
selector = get_selector()
2761+
selector.add_reader(obj, event.set)
2762+
remove_reader = selector.remove_reader
2763+
else:
2764+
remove_reader = loop.remove_reader
2765+
2766+
read_events[obj] = event
27542767
try:
27552768
await event.wait()
27562769
finally:
2757-
if read_events.pop(obj, None) is not None:
2758-
loop.remove_reader(obj)
2759-
readable = True
2760-
else:
2761-
readable = False
2762-
2763-
if not readable:
2764-
raise ClosedResourceError
2770+
remove_reader(obj)
2771+
del read_events[obj]
27652772

27662773
@classmethod
2767-
async def wait_writable(cls, obj: HasFileno | int) -> None:
2774+
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
27682775
await cls.checkpoint()
27692776
try:
27702777
write_events = _write_events.get()
@@ -2776,22 +2783,27 @@ async def wait_writable(cls, obj: HasFileno | int) -> None:
27762783
obj = obj.fileno()
27772784

27782785
if write_events.get(obj):
2779-
raise BusyResourceError("writing to") from None
2786+
raise BusyResourceError("writing to")
27802787

27812788
loop = get_running_loop()
2782-
event = write_events[obj] = asyncio.Event()
2783-
loop.add_writer(obj, event.set)
2789+
event = asyncio.Event()
2790+
try:
2791+
loop.add_writer(obj, event.set)
2792+
except NotImplementedError:
2793+
from anyio._core._asyncio_selector_thread import get_selector
2794+
2795+
selector = get_selector()
2796+
selector.add_writer(obj, event.set)
2797+
remove_writer = selector.remove_writer
2798+
else:
2799+
remove_writer = loop.remove_writer
2800+
2801+
write_events[obj] = event
27842802
try:
27852803
await event.wait()
27862804
finally:
2787-
if write_events.pop(obj, None) is not None:
2788-
loop.remove_writer(obj)
2789-
writable = True
2790-
else:
2791-
writable = False
2792-
2793-
if not writable:
2794-
raise ClosedResourceError
2805+
del write_events[obj]
2806+
remove_writer(obj)
27952807

27962808
@classmethod
27972809
def current_default_thread_limiter(cls) -> CapacityLimiter:
+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import socket
5+
import threading
6+
from collections.abc import Callable
7+
from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector
8+
from typing import TYPE_CHECKING, Any
9+
10+
if TYPE_CHECKING:
11+
from _typeshed import FileDescriptorLike
12+
13+
_selector_lock = threading.Lock()
14+
_selector: Selector | None = None
15+
16+
17+
class Selector:
18+
def __init__(self) -> None:
19+
self._thread = threading.Thread(target=self.run, name="AnyIO socket selector")
20+
self._selector = DefaultSelector()
21+
self._send, self._receive = socket.socketpair()
22+
self._send.setblocking(False)
23+
self._receive.setblocking(False)
24+
self._selector.register(self._receive, EVENT_READ)
25+
self._closed = False
26+
27+
def start(self) -> None:
28+
self._thread.start()
29+
threading._register_atexit(self._stop) # type: ignore[attr-defined]
30+
31+
def _stop(self) -> None:
32+
global _selector
33+
self._closed = True
34+
self._notify_self()
35+
self._send.close()
36+
self._thread.join()
37+
self._selector.unregister(self._receive)
38+
self._receive.close()
39+
self._selector.close()
40+
_selector = None
41+
assert (
42+
not self._selector.get_map()
43+
), "selector still has registered file descriptors after shutdown"
44+
45+
def _notify_self(self) -> None:
46+
try:
47+
self._send.send(b"\x00")
48+
except BlockingIOError:
49+
pass
50+
51+
def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
52+
loop = asyncio.get_running_loop()
53+
try:
54+
key = self._selector.get_key(fd)
55+
except KeyError:
56+
self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)})
57+
else:
58+
if EVENT_READ in key.data:
59+
raise ValueError(
60+
"this file descriptor is already registered for reading"
61+
)
62+
63+
key.data[EVENT_READ] = loop, callback
64+
self._selector.modify(fd, key.events | EVENT_READ, key.data)
65+
66+
self._notify_self()
67+
68+
def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
69+
loop = asyncio.get_running_loop()
70+
try:
71+
key = self._selector.get_key(fd)
72+
except KeyError:
73+
self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)})
74+
else:
75+
if EVENT_WRITE in key.data:
76+
raise ValueError(
77+
"this file descriptor is already registered for writing"
78+
)
79+
80+
key.data[EVENT_WRITE] = loop, callback
81+
self._selector.modify(fd, key.events | EVENT_WRITE, key.data)
82+
83+
self._notify_self()
84+
85+
def remove_reader(self, fd: FileDescriptorLike) -> bool:
86+
try:
87+
key = self._selector.get_key(fd)
88+
except KeyError:
89+
return False
90+
91+
if new_events := key.events ^ EVENT_READ:
92+
del key.data[EVENT_READ]
93+
self._selector.modify(fd, new_events, key.data)
94+
else:
95+
self._selector.unregister(fd)
96+
97+
return True
98+
99+
def remove_writer(self, fd: FileDescriptorLike) -> bool:
100+
try:
101+
key = self._selector.get_key(fd)
102+
except KeyError:
103+
return False
104+
105+
if new_events := key.events ^ EVENT_WRITE:
106+
del key.data[EVENT_WRITE]
107+
self._selector.modify(fd, new_events, key.data)
108+
else:
109+
self._selector.unregister(fd)
110+
111+
return True
112+
113+
def run(self) -> None:
114+
while not self._closed:
115+
for key, events in self._selector.select():
116+
if key.fileobj is self._receive:
117+
try:
118+
while self._receive.recv(4096):
119+
pass
120+
except BlockingIOError:
121+
pass
122+
123+
continue
124+
125+
if events & EVENT_READ:
126+
loop, callback = key.data[EVENT_READ]
127+
self.remove_reader(key.fd)
128+
try:
129+
loop.call_soon_threadsafe(callback)
130+
except RuntimeError:
131+
pass # the loop was already closed
132+
133+
if events & EVENT_WRITE:
134+
loop, callback = key.data[EVENT_WRITE]
135+
self.remove_writer(key.fd)
136+
try:
137+
loop.call_soon_threadsafe(callback)
138+
except RuntimeError:
139+
pass # the loop was already closed
140+
141+
142+
def get_selector() -> Selector:
143+
global _selector
144+
145+
with _selector_lock:
146+
if _selector is None:
147+
_selector = Selector()
148+
_selector.start()
149+
150+
return _selector

src/anyio/_core/_sockets.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
from ._tasks import create_task_group, move_on_after
3333

3434
if TYPE_CHECKING:
35-
from _typeshed import HasFileno
35+
from _typeshed import FileDescriptorLike
3636
else:
37-
HasFileno = object
37+
FileDescriptorLike = object
3838

3939
if sys.version_info < (3, 11):
4040
from exceptiongroup import ExceptionGroup
@@ -609,9 +609,6 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
609609
610610
Wait until the given socket has data to be read.
611611
612-
This does **NOT** work on Windows when using the asyncio backend with a proactor
613-
event loop (default on py3.8+).
614-
615612
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
616613
level constructs like socket streams!
617614
@@ -649,7 +646,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
649646
return get_async_backend().wait_writable(sock.fileno())
650647

651648

652-
def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
649+
def wait_readable(obj: FileDescriptorLike) -> Awaitable[None]:
653650
"""
654651
Wait until the given object has data to be read.
655652
@@ -663,10 +660,11 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
663660
descriptors aren't supported, and neither are handles that refer to anything besides
664661
a ``SOCKET``.
665662
666-
This does **NOT** work on Windows when using the asyncio backend with a proactor
667-
event loop (default on py3.8+).
663+
On backends where this functionality is not natively provided (asyncio
664+
``ProactorEventLoop`` on Windows), it is provided using a separate selector thread
665+
which is set to shut down when the interpreter shuts down.
668666
669-
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
667+
.. warning:: Don't use this on raw sockets that have been wrapped by any higher
670668
level constructs like socket streams!
671669
672670
:param obj: an object with a ``.fileno()`` method or an integer handle
@@ -679,25 +677,22 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
679677
return get_async_backend().wait_readable(obj)
680678

681679

682-
def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
680+
def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
683681
"""
684682
Wait until the given object can be written to.
685683
686-
This does **NOT** work on Windows when using the asyncio backend with a proactor
687-
event loop (default on py3.8+).
688-
689-
.. seealso:: See the documentation of :func:`wait_readable` for the definition of
690-
``obj``.
691-
692-
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
693-
level constructs like socket streams!
694-
695684
:param obj: an object with a ``.fileno()`` method or an integer handle
696685
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
697686
object to become writable
698687
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
699688
to become writable
700689
690+
.. seealso:: See the documentation of :func:`wait_readable` for the definition of
691+
``obj`` and notes on backend compatibility.
692+
693+
.. warning:: Don't use this on raw sockets that have been wrapped by any higher
694+
level constructs like socket streams!
695+
701696
"""
702697
return get_async_backend().wait_writable(obj)
703698

tests/test_sockets.py

+6-20
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from exceptiongroup import ExceptionGroup
6666

6767
if TYPE_CHECKING:
68-
from _typeshed import HasFileno
68+
from _typeshed import FileDescriptorLike
6969

7070
AnyIPAddressFamily = Literal[
7171
AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
@@ -1858,16 +1858,7 @@ async def test_connect_tcp_getaddrinfo_context() -> None:
18581858

18591859
@pytest.mark.parametrize("socket_type", ["socket", "fd"])
18601860
@pytest.mark.parametrize("event", ["readable", "writable"])
1861-
async def test_wait_socket(
1862-
anyio_backend_name: str, event: str, socket_type: str
1863-
) -> None:
1864-
if anyio_backend_name == "asyncio" and platform.system() == "Windows":
1865-
import asyncio
1866-
1867-
policy = asyncio.get_event_loop_policy()
1868-
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
1869-
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")
1870-
1861+
async def test_wait_socket(event: str, socket_type: str) -> None:
18711862
wait = wait_readable if event == "readable" else wait_writable
18721863

18731864
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
@@ -1880,20 +1871,15 @@ async def test_wait_socket(
18801871

18811872
conn, addr = server_sock.accept()
18821873
with conn:
1883-
sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn
1884-
with fail_after(10):
1874+
sock_or_fd: FileDescriptorLike = (
1875+
conn.fileno() if socket_type == "fd" else conn
1876+
)
1877+
with fail_after(3):
18851878
await wait(sock_or_fd)
18861879
assert conn.recv(1024) == b"Hello, world"
18871880

18881881

18891882
async def test_deprecated_wait_socket(anyio_backend_name: str) -> None:
1890-
if anyio_backend_name == "asyncio" and platform.system() == "Windows":
1891-
import asyncio
1892-
1893-
policy = asyncio.get_event_loop_policy()
1894-
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
1895-
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")
1896-
18971883
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
18981884
with pytest.warns(
18991885
DeprecationWarning,

0 commit comments

Comments
 (0)