Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notify ChannelQueue that the response router thread is finishing #896

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,14 @@ def cleanup_resources(self, restart=False):
class ChannelQueue(Queue):

channel_name: Optional[str] = None
response_router_finished: bool

def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger):
super().__init__()
self.channel_name = channel_name
self.channel_socket = channel_socket
self.log = log
self.response_router_finished = False

async def _async_get(self, timeout=None):
if timeout is None:
Expand All @@ -516,6 +518,8 @@ async def _async_get(self, timeout=None):
try:
return self.get(block=False)
except Empty:
if self.response_router_finished:
raise RuntimeError("Response router had finished")
if monotonic() > end_time:
raise
await asyncio.sleep(0)
Expand Down Expand Up @@ -598,16 +602,16 @@ class GatewayKernelClient(AsyncKernelClient):
# flag for whether execute requests should be allowed to call raw_input:
allow_stdin = False
_channels_stopped: bool
_channel_queues: Optional[dict]
_channel_queues: Optional[Dict[str, ChannelQueue]]
_control_channel: Optional[ChannelQueue]
_hb_channel: Optional[ChannelQueue]
_stdin_channel: Optional[ChannelQueue]
_iopub_channel: Optional[ChannelQueue]
_shell_channel: Optional[ChannelQueue]

def __init__(self, **kwargs):
def __init__(self, kernel_id, **kwargs):
super().__init__(**kwargs)
self.kernel_id = kwargs["kernel_id"]
self.kernel_id = kernel_id
self.channel_socket: Optional[websocket.WebSocket] = None
self.response_router: Optional[Thread] = None
self._channels_stopped = False
Expand Down Expand Up @@ -644,13 +648,14 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
enable_multithread=True,
sslopt=ssl_options,
)
self.response_router = Thread(target=self._route_responses)
self.response_router.start()

await ensure_async(
super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)
)

self.response_router = Thread(target=self._route_responses)
self.response_router.start()

def stop_channels(self):
"""Stops all the running channels for this kernel.

Expand Down Expand Up @@ -753,6 +758,11 @@ def _route_responses(self):
if not self._channels_stopped:
self.log.warning(f"Unexpected exception encountered ({be})")

# Notify channel queues that this thread had finished and no more messages are being received
assert self._channel_queues is not None
for channel_queue in self._channel_queues.values():
channel_queue.response_router_finished = True

self.log.debug("Response router thread exiting...")


Expand Down
56 changes: 55 additions & 1 deletion tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.web import HTTPError

from jupyter_server.gateway.managers import ChannelQueue, GatewayClient
from jupyter_server.gateway.managers import (
ChannelQueue,
GatewayClient,
GatewayKernelManager,
)
from jupyter_server.utils import ensure_async

from .utils import expected_http_error
Expand Down Expand Up @@ -164,6 +168,15 @@ async def mock_gateway_request(url, **kwargs):
mock_http_user = "alice"


def mock_websocket_create_connection(recv_side_effect=None):
def helper(*args, **kwargs):
mock = MagicMock()
mock.recv = MagicMock(side_effect=recv_side_effect)
return mock

return helper


@pytest.fixture
def init_gateway(monkeypatch):
"""Initializes the server for use as a gateway client."""
Expand Down Expand Up @@ -321,6 +334,39 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke
assert await is_kernel_running(jp_fetch, k2) is False


@patch("websocket.create_connection", mock_websocket_create_connection(recv_side_effect=Exception))
async def test_kernel_client_response_router_notifies_channel_queue_when_finished(
init_gateway, jp_serverapp, jp_fetch
):
# create
kernel_id = await create_kernel(jp_fetch, "kspec_bar")

# get kernel manager
km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)

# create kernel client
kc = km.client()

await ensure_async(kc.start_channels())

with pytest.raises(RuntimeError):
await kc.iopub_channel.get_msg(timeout=10)

all_channels = [
kc.shell_channel,
kc.iopub_channel,
kc.stdin_channel,
kc.hb_channel,
kc.control_channel,
]
assert all(channel.response_router_finished if True else False for channel in all_channels)

await ensure_async(kc.stop_channels())

# delete
await delete_kernel(jp_fetch, kernel_id)


async def test_channel_queue_get_msg_with_invalid_timeout():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())

Expand Down Expand Up @@ -352,6 +398,14 @@ async def test_channel_queue_get_msg_with_existing_item():
assert received_message == sent_message


async def test_channel_queue_get_msg_when_response_router_had_finished():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
queue.response_router_finished = True

with pytest.raises(RuntimeError):
await queue.get_msg()


#
# Test methods below...
#
Expand Down