diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a39b1b4aeaf20..8bf96c0283378 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -43,12 +43,12 @@
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import Counter
+from vllm import utils
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
-
 
 def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
     config = try_get_generation_config(
         model_config.model,
@@ -357,6 +357,7 @@ def __init__(
         self.previous_output = None
         self.previous_scheduler_outputs = None
         self.previous_seq_group_metadata_list = None
+        self.request_outputs = None
 
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -803,11 +804,7 @@ def _process_sequence_group_outputs(
         return
 
     def _process_model_outputs(
-        self,
-        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
-        scheduled_seq_groups: List[ScheduledSequenceGroup],
-        ignored_seq_groups: List[SequenceGroup],
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        self
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Apply the model output to the sequences in the
         scheduled seq groups.
@@ -816,6 +813,11 @@ def _process_model_outputs(
 
         now = time.time()
 
+        scheduled_seq_groups = self.previous_scheduler_outputs.scheduled_seq_groups
+        ignored_seq_groups = self.previous_scheduler_outputs.ignored_seq_groups
+        output = self.previous_output
+        seq_group_metadata_list = self.previous_seq_group_metadata_list
+
         # Organize outputs by [sequence group][step] instead of
         # [step][sequence group].
         output_by_sequence_group = create_output_by_sequence_group(
@@ -851,7 +853,8 @@ def _process_model_outputs(
         for seq_group in ignored_seq_groups:
             request_output = RequestOutputFactory.create(seq_group)
             request_outputs.append(request_output)
-        return request_outputs
+        self.request_outputs = request_outputs
+        return
 
     def _advance_to_next_step(
         self,
@@ -931,12 +934,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             raise NotImplementedError(
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")
-        request_outputs = None
-        if (self.previous_output) and (len(self.previous_output) > 0):
-            request_outputs = self._process_model_outputs(
-                self.previous_output, self.previous_scheduler_outputs.scheduled_seq_groups,
-                self.previous_scheduler_outputs.ignored_seq_groups, self.previous_seq_group_metadata_list)
-
         seq_group_metadata_list, scheduler_outputs = self.scheduler[
             0].schedule()
 
@@ -952,10 +949,12 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 running_queue_size=scheduler_outputs.running_queue_size,
                 finished_requests_ids=finished_requests_ids)
             output = self.model_executor.execute_model(
-                execute_model_req=execute_model_req)
+                execute_model_req=execute_model_req, callback_fn=self._process_model_outputs)
         else:
             output = []
 
+        # hack to avoid callback function for first step
+        utils.flag_for_callback_fn = True
         self.previous_output = output
         self.previous_scheduler_outputs = scheduler_outputs
         self.previous_seq_group_metadata_list = seq_group_metadata_list
@@ -978,7 +977,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             # queued control plane messages, such as add/remove lora adapters.
             self.model_executor.stop_remote_worker_execution_loop()
 
-        return request_outputs
+        return self.request_outputs
 
     def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
         if logger_name in self.stat_loggers:
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index 4df54a09e5e8c..4897f78e3824f 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -65,7 +65,8 @@ def initialize_cache(self, num_gpu_blocks: int,
 
     def execute_model(
         self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest,
+        callback_fn = None) -> List[SamplerOutput]:
         if self.parallel_worker_tasks is None:
             self.parallel_worker_tasks = self._run_workers(
                 "start_worker_execution_loop",
@@ -73,7 +74,7 @@ def execute_model(
             **self.extra_execute_model_run_workers_kwargs)
 
         # Only the driver worker returns the sampling results.
-        driver_outputs = self._driver_execute_model(execute_model_req)
+        driver_outputs = self._driver_execute_model(execute_model_req, callback_fn)
         assert driver_outputs is not None
         return driver_outputs
 
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index a848bc70941c1..1c4159a7d2f22 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -75,7 +75,8 @@ def initialize_cache(self, num_gpu_blocks: int,
 
     @abstractmethod
     def execute_model(
-        self, execute_model_req: ExecuteModelRequest
+        self, execute_model_req: ExecuteModelRequest,
+        callback_fn = None
     ) -> Optional[List[SamplerOutput]]:
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 08a35a074b37b..09dd56dd5a564 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -145,14 +145,15 @@ def shutdown(self):
             worker_monitor.close()
 
     def _driver_execute_model(
-        self, execute_model_req: Optional[ExecuteModelRequest]
+        self, execute_model_req: Optional[ExecuteModelRequest],
+        callback_fn = None
     ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.
 
         Passing None will cause the driver to stop the model execution
         loop running in each of the remote workers.
         """
-        return self.driver_worker.execute_model(execute_model_req)
+        return self.driver_worker.execute_model(execute_model_req, callback_fn)
 
     def _run_workers(
         self,
diff --git a/vllm/utils.py b/vllm/utils.py
index 51bd72977a226..cb8b41104e02d 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -29,6 +29,9 @@
 from vllm import _custom_ops as ops
 from vllm.logger import enable_trace_function_call, init_logger
 
+global flag_for_callback_fn
+flag_for_callback_fn = False
+
 logger = init_logger(__name__)
 
 STR_DTYPE_TO_TORCH_DTYPE = {
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index f9c26e0c318b1..a2160e7718283 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -53,6 +53,7 @@
 from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
                         get_kv_cache_torch_dtype, is_hip,
                         is_pin_memory_available)
+from vllm import utils
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -137,7 +138,6 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
     # Used for speculative decoding. We do not broadcast it because it is only
     # used by the driver worker.
     is_prompt: Optional[bool] = None
-
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
             "input_tokens": self.input_tokens,
@@ -1292,6 +1292,7 @@ def execute_model(
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        callback_fn = None
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1380,6 +1381,8 @@ def execute_model(
         if not self.is_driver_worker:
             return []
 
+        if utils.flag_for_callback_fn and callback_fn is not None:
+            callback_fn()
         # Sample the next token.
         output: SamplerOutput = self.model.sample(
             logits=logits,
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index 46ac16b504bf4..5c8c0cbd2b0f0 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -174,6 +174,7 @@ def execute_model(
         kv_caches: Optional[List[torch.Tensor]],
         intermediate_tensors: Optional[IntermediateTensors],
         num_steps: int = 1,
+        callback_fn = None
     ) -> Optional[List[SamplerOutput]]:
         """
         Execute the model on the given input.
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index e56440693b895..7464920c4664d 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -215,7 +215,8 @@ def execute_worker(self, worker_input: WorkerInput) -> None:
 
     def execute_model(
         self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
+        execute_model_req: Optional[ExecuteModelRequest] = None,
+        callback_fn = None
     ) -> Optional[List[SamplerOutput]]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
@@ -273,7 +274,7 @@ def execute_model(
         output = self.model_runner.execute_model(
             model_input, self.kv_cache[worker_input.virtual_engine]
             if self.kv_cache is not None else None, intermediate_tensors,
-            num_steps)
+            num_steps, callback_fn)
 
         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
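
Below is a minimal, self-contained sketch of the control flow this patch wires up, using toy stand-in classes rather than the real vllm `LLMEngine`/`ModelRunner` APIs: `step()` hands `_process_model_outputs` down as `callback_fn`, the model runner invokes it just before sampling, and the callback reads the previous step's output from engine state instead of taking it as arguments. The real patch gates the first step with the module-level `flag_for_callback_fn` in `vllm.utils`; the sketch simply skips the callback when there is no previous output yet.

```python
# Toy sketch of the callback flow introduced above (stand-in classes,
# not the actual vllm Engine/ModelRunner APIs).
from typing import Callable, List, Optional


class ToyModelRunner:
    """Stands in for ModelRunner.execute_model()."""

    def execute_model(self, step: int,
                      callback_fn: Optional[Callable[[], None]] = None) -> List[str]:
        # Invoke the engine callback before sampling, mirroring the patch:
        # the previous step's outputs are processed while this step runs.
        if callback_fn is not None:
            callback_fn()
        return [f"token@step{step}"]


class ToyEngine:
    """Stands in for LLMEngine.step() state handling."""

    def __init__(self) -> None:
        self.runner = ToyModelRunner()
        self.previous_output: Optional[List[str]] = None
        self.request_outputs: Optional[List[str]] = None
        self._step = 0

    def _process_model_outputs(self) -> None:
        # Reads the previous step's output from engine state instead of
        # taking it as arguments, like the modified _process_model_outputs.
        if self.previous_output is not None:
            self.request_outputs = [o.upper() for o in self.previous_output]

    def step(self) -> Optional[List[str]]:
        # Skip the callback on the very first step (no previous output yet);
        # the patch does this with utils.flag_for_callback_fn instead.
        callback = self._process_model_outputs if self._step > 0 else None
        output = self.runner.execute_model(self._step, callback_fn=callback)
        self.previous_output = output
        self._step += 1
        return self.request_outputs


engine = ToyEngine()
print(engine.step())  # None: first step has no previous output to process
print(engine.step())  # ['TOKEN@STEP0']: previous step processed via callback
```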