From 051868df56ac70429712f7d83d5df2658898ee6a Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Fri, 23 Aug 2024 15:05:57 +0000
Subject: [PATCH] refactor the append_token_id to its original form

---
 vllm/engine/llm_engine.py                   | 15 ++++-----------
 vllm/engine/output_processor/single_step.py | 12 +++++++-----
 vllm/sequence.py                            | 10 +++-------
 3 files changed, 14 insertions(+), 23 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c8591e893eb6d..6795b38194b12 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1313,18 +1313,11 @@ def _advance_to_next_step(
                     "output_proc_callback expects a single sample"
                     " (i.e sampling_params.n == 1 and no "
                     "sampling_params.best_of > 1)")
-                seq_group_metadata.is_prompt = False
-                seq_output = sequence_group_outputs.samples[0]
+                sample = sequence_group_outputs.samples[0]

-                # NOTE: Beam search is not supported, so we can assume that
-                # parent_seq_id == seq_id.
-                seq_data = seq_group_metadata.seq_data[
-                    seq_output.parent_seq_id]
-
-                token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
-
-                seq_data.append_token_id(token_id, token_logprob.logprob)
+                assert len(seq_group.seqs) == 1
+                seq = seq_group.seqs[0]
+                seq.append_token_id(sample.output_token, sample.logprobs)

     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.

diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index b9c1b8cf553c8..104f51ea111e3 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -83,10 +83,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         if sampling_params.n == 1 and not sampling_params.use_beam_search:
             if len(outputs.samples) > 0:
                 sample = outputs.samples[0]
-                # only have one sequence
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs,
-                                    not is_async)
+                if not is_async:
+                    seq.append_token_id(sample.output_token, sample.logprobs)
                 if sampling_params.detokenize and self.detokenizer:
                     new_char_count = self.detokenizer.decode_sequence_inplace(
                         seq, sampling_params)
@@ -105,6 +104,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             # is still not finished
             return

+        # TODO: Add support for below cases for async
+        assert not is_async
+
         # Process samples
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
@@ -140,14 +142,14 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             new_child_seq_id: int = next(self.seq_counter)
             child = parent.fork(new_child_seq_id)
             child.append_token_id(child_sample.output_token,
-                                  child_sample.logprobs, not is_async)
+                                  child_sample.logprobs)
             child_seqs.append((child, parent))
         # Continue the parent sequence for the last child sample.
         # We reuse the parent sequence here to reduce redundant memory
         # copies, especially when using non-beam search sampling methods.
         last_child_sample = child_samples[-1]
         parent.append_token_id(last_child_sample.output_token,
-                               last_child_sample.logprobs, not is_async)
+                               last_child_sample.logprobs)
         child_seqs.append((parent, parent))

         for seq, _ in child_seqs:

diff --git a/vllm/sequence.py b/vllm/sequence.py
index c5b52a1d55bfd..a94364c296ad3 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -474,15 +474,11 @@ def reset_state_for_recompute(self):
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()

-    def append_token_id(self,
-                        token_id: int,
-                        logprobs: Dict[int, Logprob],
-                        update_seq_data: bool = True) -> None:
+    def append_token_id(self, token_id: int, logprobs: Dict[int,
+                                                            Logprob]) -> None:
         assert token_id in logprobs
         self.output_logprobs.append(logprobs)
-        # Only do this when output proc callback is not used
-        if update_seq_data:
-            self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob)

     def get_len(self) -> int:
         return self.data.get_len()
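Note: the sketch below is not vLLM code; it is a minimal standalone illustration of the call pattern this patch restores. append_token_id no longer takes an update_seq_data flag; the sequence-data update is unconditional, and the async-vs-sync distinction moves to the callers (the engine appends in _advance_to_next_step on the async path, so single_step.py skips the append when is_async is true). The Logprob and Sequence classes here are simplified stand-ins, and process_sample is a hypothetical helper standing in for _process_sequence_group_outputs.

# Sketch only: simplified stand-ins for vLLM's Logprob/Sequence.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Logprob:
    logprob: float


@dataclass
class Sequence:
    token_ids: List[int] = field(default_factory=list)
    output_logprobs: List[Dict[int, Logprob]] = field(default_factory=list)

    def append_token_id(self, token_id: int,
                        logprobs: Dict[int, Logprob]) -> None:
        # Restored form: no update_seq_data flag; the data update
        # always happens when this method is called.
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.token_ids.append(token_id)


def process_sample(seq: Sequence, token_id: int,
                   logprobs: Dict[int, Logprob], is_async: bool) -> None:
    # Caller-side guard, mirroring single_step.py: on the async path the
    # engine already appended this token in _advance_to_next_step, so the
    # output processor must not append it a second time.
    if not is_async:
        seq.append_token_id(token_id, logprobs)


seq = Sequence()
process_sample(seq, 7, {7: Logprob(-0.1)}, is_async=False)  # appended
process_sample(seq, 9, {9: Logprob(-0.2)}, is_async=True)   # skipped here
print(seq.token_ids)  # [7]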