diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 828d7eed309f6..a7c02322ff02d 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -3,7 +3,6 @@ import asyncio import time import uuid -from contextlib import ExitStack from typing import Dict, List, Optional import pytest @@ -178,7 +177,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_asyncio(monkeypatch): - with monkeypatch.context() as m, ExitStack() as after: + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") # Monkey-patch core engine utility function to test. @@ -195,7 +194,6 @@ async def test_engine_core_client_asyncio(monkeypatch): executor_class=executor_class, log_stats=True, ) - after.callback(client.shutdown) MAX_TOKENS = 20 params = SamplingParams(max_tokens=MAX_TOKENS) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 527aa72833baf..5ffaf63e6cec6 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -8,6 +8,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import Future +from dataclasses import dataclass from threading import Thread from typing import Any, Dict, List, Optional, Type, Union @@ -169,6 +170,31 @@ def add_lora(self, lora_request: LoRARequest) -> None: self.engine_core.add_lora(lora_request) +@dataclass +class BackgroundResources: + """Used as a finalizer for clean shutdown, avoiding + circular reference back to the client object.""" + + ctx: Union[zmq.Context, zmq.asyncio.Context] = None + output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None + input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None + proc_handle: Optional[BackgroundProcHandle] = None + + def __call__(self): + """Clean up background resources.""" + + if self.proc_handle is not None: + self.proc_handle.shutdown() + # ZMQ context termination can hang if the sockets + # aren't explicitly closed first. + if self.output_socket is not None: + self.output_socket.close(linger=0) + if self.input_socket is not None: + self.input_socket.close(linger=0) + if self.ctx is not None: + self.ctx.destroy(linger=0) + + class MPClient(EngineCoreClient): """ MPClient: base client for multi-proc EngineCore. @@ -212,21 +238,22 @@ def sigusr1_handler(signum, frame): zmq.asyncio.Context() # type: ignore[attr-defined] if asyncio_mode else zmq.Context()) # type: ignore[attr-defined] - # Note(rob): shutdown function cannot be a bound method, - # else the gc cannot collect the object. - self._finalizer = weakref.finalize(self, lambda x: x.destroy(linger=0), - self.ctx) + # This will ensure resources created so far are closed + # when the client is garbage collected, even if an + # exception is raised mid-construction. + resources = BackgroundResources(ctx=self.ctx) + self._finalizer = weakref.finalize(self, resources) # Paths and sockets for IPC. output_path = get_open_zmq_ipc_path() input_path = get_open_zmq_ipc_path() - self.output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - self.input_socket = make_zmq_socket(self.ctx, input_path, - zmq.constants.PUSH) + resources.output_socket = make_zmq_socket(self.ctx, output_path, + zmq.constants.PULL) + resources.input_socket = make_zmq_socket(self.ctx, input_path, + zmq.constants.PUSH) # Start EngineCore in background process. - self.proc_handle = BackgroundProcHandle( + resources.proc_handle = BackgroundProcHandle( input_path=input_path, output_path=output_path, process_name="EngineCore", @@ -237,13 +264,11 @@ def sigusr1_handler(signum, frame): "log_stats": log_stats, }) + self.output_socket = resources.output_socket + self.input_socket = resources.input_socket self.utility_results: Dict[int, AnyFuture] = {} def shutdown(self): - """Clean up background resources.""" - if hasattr(self, "proc_handle"): - self.proc_handle.shutdown() - self._finalizer()