Skip to content

Commit 65d779a

Browse files
Backport PR #896 on branch 1.x (Notify ChannelQueue that the response router thread is finishing) (#897)
Co-authored-by: Ciprian Anton <[email protected]>
1 parent c3032a0 commit 65d779a

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

jupyter_server/gateway/managers.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -498,12 +498,14 @@ def cleanup_resources(self, restart=False):
498498
class ChannelQueue(Queue):
499499

500500
channel_name: Optional[str] = None
501+
response_router_finished: bool
501502

502503
def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger):
503504
super().__init__()
504505
self.channel_name = channel_name
505506
self.channel_socket = channel_socket
506507
self.log = log
508+
self.response_router_finished = False
507509

508510
async def _async_get(self, timeout=None):
509511
if timeout is None:
@@ -516,6 +518,8 @@ async def _async_get(self, timeout=None):
516518
try:
517519
return self.get(block=False)
518520
except Empty:
521+
if self.response_router_finished:
522+
raise RuntimeError("Response router had finished")
519523
if monotonic() > end_time:
520524
raise
521525
await asyncio.sleep(0)
@@ -597,19 +601,21 @@ class GatewayKernelClient(AsyncKernelClient):
597601

598602
# flag for whether execute requests should be allowed to call raw_input:
599603
allow_stdin = False
600-
_channels_stopped = False
601-
_channel_queues: Optional[dict] = {}
604+
_channels_stopped: bool
605+
_channel_queues: Optional[Dict[str, ChannelQueue]]
602606
_control_channel: Optional[ChannelQueue]
603607
_hb_channel: Optional[ChannelQueue]
604608
_stdin_channel: Optional[ChannelQueue]
605609
_iopub_channel: Optional[ChannelQueue]
606610
_shell_channel: Optional[ChannelQueue]
607611

608-
def __init__(self, **kwargs):
612+
def __init__(self, kernel_id, **kwargs):
609613
super().__init__(**kwargs)
610-
self.kernel_id = kwargs["kernel_id"]
614+
self.kernel_id = kernel_id
611615
self.channel_socket: Optional[websocket.WebSocket] = None
612616
self.response_router: Optional[Thread] = None
617+
self._channels_stopped = False
618+
self._channel_queues = {}
613619

614620
# --------------------------------------------------------------------------
615621
# Channel management methods
@@ -642,13 +648,14 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
642648
enable_multithread=True,
643649
sslopt=ssl_options,
644650
)
645-
self.response_router = Thread(target=self._route_responses)
646-
self.response_router.start()
647651

648652
await ensure_async(
649653
super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)
650654
)
651655

656+
self.response_router = Thread(target=self._route_responses)
657+
self.response_router.start()
658+
652659
def stop_channels(self):
653660
"""Stops all the running channels for this kernel.
654661
@@ -751,6 +758,11 @@ def _route_responses(self):
751758
if not self._channels_stopped:
752759
self.log.warning(f"Unexpected exception encountered ({be})")
753760

761+
# Notify channel queues that this thread had finished and no more messages are being received
762+
assert self._channel_queues is not None
763+
for channel_queue in self._channel_queues.values():
764+
channel_queue.response_router_finished = True
765+
754766
self.log.debug("Response router thread exiting...")
755767

756768

tests/test_gateway.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from tornado.httpclient import HTTPRequest, HTTPResponse
1515
from tornado.web import HTTPError
1616

17-
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient
17+
from jupyter_server.gateway.managers import (
18+
ChannelQueue,
19+
GatewayClient,
20+
GatewayKernelManager,
21+
)
1822
from jupyter_server.utils import ensure_async
1923

2024
from .utils import expected_http_error
@@ -164,6 +168,15 @@ async def mock_gateway_request(url, **kwargs):
164168
mock_http_user = "alice"
165169

166170

171+
def mock_websocket_create_connection(recv_side_effect=None):
172+
def helper(*args, **kwargs):
173+
mock = MagicMock()
174+
mock.recv = MagicMock(side_effect=recv_side_effect)
175+
return mock
176+
177+
return helper
178+
179+
167180
@pytest.fixture
168181
def init_gateway(monkeypatch):
169182
"""Initializes the server for use as a gateway client."""
@@ -321,6 +334,39 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke
321334
assert await is_kernel_running(jp_fetch, k2) is False
322335

323336

337+
@patch("websocket.create_connection", mock_websocket_create_connection(recv_side_effect=Exception))
338+
async def test_kernel_client_response_router_notifies_channel_queue_when_finished(
339+
init_gateway, jp_serverapp, jp_fetch
340+
):
341+
# create
342+
kernel_id = await create_kernel(jp_fetch, "kspec_bar")
343+
344+
# get kernel manager
345+
km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)
346+
347+
# create kernel client
348+
kc = km.client()
349+
350+
await ensure_async(kc.start_channels())
351+
352+
with pytest.raises(RuntimeError):
353+
await kc.iopub_channel.get_msg(timeout=10)
354+
355+
all_channels = [
356+
kc.shell_channel,
357+
kc.iopub_channel,
358+
kc.stdin_channel,
359+
kc.hb_channel,
360+
kc.control_channel,
361+
]
362+
assert all(channel.response_router_finished if True else False for channel in all_channels)
363+
364+
await ensure_async(kc.stop_channels())
365+
366+
# delete
367+
await delete_kernel(jp_fetch, kernel_id)
368+
369+
324370
async def test_channel_queue_get_msg_with_invalid_timeout():
325371
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
326372

@@ -352,6 +398,14 @@ async def test_channel_queue_get_msg_with_existing_item():
352398
assert received_message == sent_message
353399

354400

401+
async def test_channel_queue_get_msg_when_response_router_had_finished():
402+
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
403+
queue.response_router_finished = True
404+
405+
with pytest.raises(RuntimeError):
406+
await queue.get_msg()
407+
408+
355409
#
356410
# Test methods below...
357411
#

0 commit comments

Comments
 (0)