From 64b8b0b3f2b86a1fefcc8a4c0d22ecbbe32ea6ea Mon Sep 17 00:00:00 2001
From: meghagarwal <16129366+megha95@users.noreply.github.com>
Date: Mon, 12 Aug 2024 19:34:44 +0000
Subject: [PATCH] update stop status

---
 vllm/core/scheduler.py    |  7 +++---
 vllm/engine/llm_engine.py | 49 +++++++++++++++++++++++++++------------
 vllm/entrypoints/llm.py   |  3 ---
 3 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 2e32de121db7a..1523d9a9a0dfa 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1041,11 +1041,12 @@ def free_finished_seq_groups(self) -> None:
                 # This list will be used to update the Mamba cache in the
                 # next step.
                 self._finished_requests_ids.append(seq_group.request_id)
-                # Free finished seqs
-                for seq in seq_group.get_seqs():
-                    self.free_seq(seq)
             else:
                 remaining.append(seq_group)
+            # Free finished seqs
+            for seq in seq_group.get_seqs():
+                if seq.is_finished():
+                    self.free_seq(seq)
         self.running = remaining
 
     def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c2e20dd31b3d8..12c46e522b1ae 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -358,7 +358,7 @@ def __init__(
         self.previous_output = None
         self.previous_scheduler_outputs = None
         self.previous_seq_group_metadata_list = None
-        self.request_outputs = None
+        self.request_outputs = []
 
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -857,6 +857,21 @@ def _process_model_outputs(
         self.request_outputs = request_outputs
         return
 
+    def _update_stop_criteria(self,
+                              seq: Sequence,
+                              sampling_params: SamplingParams):
+        # Check if the sequence has reached max_tokens or max_model_len.
+        if (seq.get_output_len() == sampling_params.max_tokens) or (seq.get_len() >= self.scheduler_config.max_model_len):
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+
+        # Check if a stop token was encountered.
+        # This assumes a single token is produced per step.
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
+        return seq
+
     def _advance_to_next_step(
             self, output: List[SamplerOutput],
@@ -865,22 +880,26 @@ def _advance_to_next_step(
         sequences. This is normally done inside output processor, but it is
         required if the worker is to perform async forward pass to next step.
         """
-        for seq_group_metadata, sequence_group_outputs in zip(
-                seq_group_metadata_list, output):
+        scheduled_seq_groups = self.previous_scheduler_outputs.scheduled_seq_groups
+        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in zip(
+                seq_group_metadata_list, output, scheduled_seq_groups):
             assert len(sequence_group_outputs.samples) <= 1, \
                 "sampling_params.n > 1 and sampling_params.best_of > 1 not supported with output proc callback"
-            if len(sequence_group_outputs.samples) == 1:
-                seq_group_metadata.is_prompt = False
-                seq_output = sequence_group_outputs.samples[0]
-                # NOTE: Beam search is not supported, so we can assume that
-                # parent_seq_id == seq_id.
-                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
-
-                token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
-
-                seq.update_num_computed_tokens(seq_group_metadata.token_chunk_size)
-                seq.append_token_id(token_id, token_logprob.logprob)
+            seq_group = scheduled_seq_group.seq_group
+            for seq in seq_group.get_seqs():
+                self._update_stop_criteria(seq, seq_group.sampling_params)
+            if not seq_group.is_finished():
+                if len(sequence_group_outputs.samples) == 1:
+                    seq_group_metadata.is_prompt = False
+                    seq_output = sequence_group_outputs.samples[0]
+                    # NOTE: Beam search is not supported, so we can assume that
+                    # parent_seq_id == seq_id.
+                    seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                    token_id = seq_output.output_token
+                    token_logprob = seq_output.logprobs[token_id]
+                    seq.update_num_computed_tokens(seq_group_metadata.token_chunk_size)
+                    seq.append_token_id(token_id, token_logprob.logprob)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index e9293e6ac8ee1..678e7ecf07235 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -611,9 +611,6 @@ def _run_engine(
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
-            # HACK: no output returned in first step
-            if not step_outputs:
-                continue
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
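
Note (illustration only, not part of the patch): below is a minimal, self-contained sketch of the stop checks the new _update_stop_criteria helper applies per sequence, first the max_tokens / max_model_len length cap, then the stop_token_ids check. FakeSeq, Status, and update_stop_criteria are illustrative stand-ins, not vLLM types; the real helper operates on Sequence, SequenceStatus, and SamplingParams as shown in the hunk above.

# Illustrative sketch of the per-sequence stop checks, using stand-in types
# so it can run on its own without vLLM.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional


class Status(Enum):
    RUNNING = auto()
    FINISHED_LENGTH_CAPPED = auto()  # hit max_tokens or max_model_len
    FINISHED_STOPPED = auto()        # emitted a stop token


@dataclass
class FakeSeq:
    prompt_len: int
    output_tokens: List[int] = field(default_factory=list)
    status: Status = Status.RUNNING
    stop_reason: Optional[int] = None


def update_stop_criteria(seq: FakeSeq, max_tokens: int, max_model_len: int,
                         stop_token_ids: List[int]) -> FakeSeq:
    # Length cap: the request's max_tokens or the model's context limit.
    if (len(seq.output_tokens) == max_tokens
            or seq.prompt_len + len(seq.output_tokens) >= max_model_len):
        seq.status = Status.FINISHED_LENGTH_CAPPED
    # Stop token: assumes exactly one new token per step, as in the patch,
    # and that the sequence already has at least one output token.
    last = seq.output_tokens[-1]
    if last in stop_token_ids:
        seq.status = Status.FINISHED_STOPPED
        seq.stop_reason = last
    return seq


seq = FakeSeq(prompt_len=10, output_tokens=[5, 7, 2])
update_stop_criteria(seq, max_tokens=16, max_model_len=2048, stop_token_ids=[2])
print(seq.status, seq.stop_reason)  # -> Status.FINISHED_STOPPED 2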