Skip to content

Commit ee40dbc

Browse files
authored
Notify ChannelQueue that the response router thread is finishing (jupyter-server#896)
1 parent 4f1e09e commit ee40dbc

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

jupyter_server/gateway/managers.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -498,12 +498,14 @@ def cleanup_resources(self, restart=False):
498498
class ChannelQueue(Queue):
499499

500500
channel_name: Optional[str] = None
501+
response_router_finished: bool
501502

502503
def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger):
503504
super().__init__()
504505
self.channel_name = channel_name
505506
self.channel_socket = channel_socket
506507
self.log = log
508+
self.response_router_finished = False
507509

508510
async def _async_get(self, timeout=None):
509511
if timeout is None:
@@ -516,6 +518,8 @@ async def _async_get(self, timeout=None):
516518
try:
517519
return self.get(block=False)
518520
except Empty:
521+
if self.response_router_finished:
522+
raise RuntimeError("Response router had finished")
519523
if monotonic() > end_time:
520524
raise
521525
await asyncio.sleep(0)
@@ -598,16 +602,16 @@ class GatewayKernelClient(AsyncKernelClient):
598602
# flag for whether execute requests should be allowed to call raw_input:
599603
allow_stdin = False
600604
_channels_stopped: bool
601-
_channel_queues: Optional[dict]
605+
_channel_queues: Optional[Dict[str, ChannelQueue]]
602606
_control_channel: Optional[ChannelQueue]
603607
_hb_channel: Optional[ChannelQueue]
604608
_stdin_channel: Optional[ChannelQueue]
605609
_iopub_channel: Optional[ChannelQueue]
606610
_shell_channel: Optional[ChannelQueue]
607611

608-
def __init__(self, **kwargs):
612+
def __init__(self, kernel_id, **kwargs):
609613
super().__init__(**kwargs)
610-
self.kernel_id = kwargs["kernel_id"]
614+
self.kernel_id = kernel_id
611615
self.channel_socket: Optional[websocket.WebSocket] = None
612616
self.response_router: Optional[Thread] = None
613617
self._channels_stopped = False
@@ -644,13 +648,14 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
644648
enable_multithread=True,
645649
sslopt=ssl_options,
646650
)
647-
self.response_router = Thread(target=self._route_responses)
648-
self.response_router.start()
649651

650652
await ensure_async(
651653
super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)
652654
)
653655

656+
self.response_router = Thread(target=self._route_responses)
657+
self.response_router.start()
658+
654659
def stop_channels(self):
655660
"""Stops all the running channels for this kernel.
656661
@@ -753,6 +758,11 @@ def _route_responses(self):
753758
if not self._channels_stopped:
754759
self.log.warning(f"Unexpected exception encountered ({be})")
755760

761+
# Notify channel queues that this thread had finished and no more messages are being received
762+
assert self._channel_queues is not None
763+
for channel_queue in self._channel_queues.values():
764+
channel_queue.response_router_finished = True
765+
756766
self.log.debug("Response router thread exiting...")
757767

758768

tests/test_gateway.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from tornado.httpclient import HTTPRequest, HTTPResponse
1515
from tornado.web import HTTPError
1616

17-
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient
17+
from jupyter_server.gateway.managers import (
18+
ChannelQueue,
19+
GatewayClient,
20+
GatewayKernelManager,
21+
)
1822
from jupyter_server.utils import ensure_async
1923

2024
from .utils import expected_http_error
@@ -164,6 +168,15 @@ async def mock_gateway_request(url, **kwargs):
164168
mock_http_user = "alice"
165169

166170

171+
def mock_websocket_create_connection(recv_side_effect=None):
172+
def helper(*args, **kwargs):
173+
mock = MagicMock()
174+
mock.recv = MagicMock(side_effect=recv_side_effect)
175+
return mock
176+
177+
return helper
178+
179+
167180
@pytest.fixture
168181
def init_gateway(monkeypatch):
169182
"""Initializes the server for use as a gateway client."""
@@ -321,6 +334,39 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke
321334
assert await is_kernel_running(jp_fetch, k2) is False
322335

323336

337+
@patch("websocket.create_connection", mock_websocket_create_connection(recv_side_effect=Exception))
338+
async def test_kernel_client_response_router_notifies_channel_queue_when_finished(
339+
init_gateway, jp_serverapp, jp_fetch
340+
):
341+
# create
342+
kernel_id = await create_kernel(jp_fetch, "kspec_bar")
343+
344+
# get kernel manager
345+
km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)
346+
347+
# create kernel client
348+
kc = km.client()
349+
350+
await ensure_async(kc.start_channels())
351+
352+
with pytest.raises(RuntimeError):
353+
await kc.iopub_channel.get_msg(timeout=10)
354+
355+
all_channels = [
356+
kc.shell_channel,
357+
kc.iopub_channel,
358+
kc.stdin_channel,
359+
kc.hb_channel,
360+
kc.control_channel,
361+
]
362+
assert all(channel.response_router_finished if True else False for channel in all_channels)
363+
364+
await ensure_async(kc.stop_channels())
365+
366+
# delete
367+
await delete_kernel(jp_fetch, kernel_id)
368+
369+
324370
async def test_channel_queue_get_msg_with_invalid_timeout():
325371
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
326372

@@ -352,6 +398,14 @@ async def test_channel_queue_get_msg_with_existing_item():
352398
assert received_message == sent_message
353399

354400

401+
async def test_channel_queue_get_msg_when_response_router_had_finished():
402+
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
403+
queue.response_router_finished = True
404+
405+
with pytest.raises(RuntimeError):
406+
await queue.get_msg()
407+
408+
355409
#
356410
# Test methods below...
357411
#

0 commit comments

Comments
 (0)