rebase over multi-step and fix bugs
alexm-redhat committed Aug 20, 2024
1 parent 1b2e046 commit 1356ab0
Showing 2 changed files with 40 additions and 16 deletions.
13 changes: 8 additions & 5 deletions vllm/core/scheduler.py
@@ -1103,12 +1103,16 @@ def schedule(
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []

+        # TODO: Combine multi-step and async postprocessor
+        allow_output_proc_callback: bool = (
+            self.use_output_proc_callback
+            and not self.scheduler_config.is_multi_step)
+
         # Create list of scheduled request ids
         scheduled_ids: List[Tuple[ScheduledSequenceGroup,
                                   SequenceGroupMetadata]] = []
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        allow_output_proc_callback: bool = False
         for i, scheduled_seq_group in enumerate(
                 scheduler_outputs.scheduled_seq_groups):
             seq_group = scheduled_seq_group.seq_group
@@ -1209,10 +1213,9 @@ def schedule(
             )
             seq_group_metadata_list.append(seq_group_metadata)

-            if self.use_output_proc_callback:
-                allow_output_proc_callback = (
-                    allow_output_proc_callback
-                    and self._allow_output_proc_callback(seq_group))
+            if allow_output_proc_callback:
+                allow_output_proc_callback = self._allow_output_proc_callback(
+                    seq_group)

             scheduled_ids.append((scheduled_seq_group, seq_group_metadata))
         # Now that the batch has been created, we can assume all blocks in the
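
Taken as a whole, the scheduler change hoists the callback decision out of the per-sequence-group loop: allow_output_proc_callback now starts from the engine-level setting, is forced off whenever multi-step scheduling is enabled, and can only be narrowed as each scheduled group is checked. Below is a minimal, self-contained sketch of that control flow; SchedulerConfigStub, SeqGroupStub and the supports_async_postproc field are illustrative stand-ins for whatever _allow_output_proc_callback actually inspects, not vLLM classes.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SchedulerConfigStub:
    is_multi_step: bool


@dataclass
class SeqGroupStub:
    request_id: str
    supports_async_postproc: bool  # hypothetical stand-in for _allow_output_proc_callback's checks


def decide_output_proc_callback(
        use_output_proc_callback: bool,
        scheduler_config: SchedulerConfigStub,
        scheduled: List[SeqGroupStub]) -> Tuple[bool, List[str]]:
    # Engine-level setting first; multi-step scheduling disables the callback
    # outright (the "TODO: Combine multi-step and async postprocessor" above).
    allow_callback = (use_output_proc_callback
                      and not scheduler_config.is_multi_step)
    scheduled_ids: List[str] = []
    for seq_group in scheduled:
        # Once any group in the batch cannot use the callback, the whole
        # batch falls back to synchronous output processing.
        if allow_callback:
            allow_callback = seq_group.supports_async_postproc
        scheduled_ids.append(seq_group.request_id)
    return allow_callback, scheduled_ids


cfg = SchedulerConfigStub(is_multi_step=False)
batch = [SeqGroupStub("req-0", True), SeqGroupStub("req-1", False)]
print(decide_output_proc_callback(True, cfg, batch))
# -> (False, ['req-0', 'req-1'])
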
43 changes: 32 additions & 11 deletions vllm/engine/async_llm_engine.py
@@ -12,7 +12,7 @@
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
-from vllm.core.scheduler import SchedulerOutputs
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
@@ -258,6 +258,9 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None
     seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
     scheduler_outputs: Optional[SchedulerOutputs] = None
+    scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
+                                       SequenceGroupMetadata]]] = None
+    allow_output_proc_callback: bool = False


 class _AsyncLLMEngine(LLMEngine):
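
SchedulerOutputState is the per-virtual-engine cache reused across multi-step iterations, and the two added fields let scheduled_ids and the callback decision survive alongside the schedule that was already cached. A small sketch with stand-in types (not the real vLLM definitions) of why a default-constructed state doubles as a reset, which step_async relies on later in this diff:

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class SchedulerOutputStateStub:
    last_output: Optional[Any] = None
    seq_group_metadata_list: Optional[List[Any]] = None
    scheduler_outputs: Optional[Any] = None
    scheduled_ids: Optional[List[Tuple[Any, Any]]] = None  # new field
    allow_output_proc_callback: bool = False                # new field


state = SchedulerOutputStateStub(scheduled_ids=[("sg", "meta")],
                                 allow_output_proc_callback=True)
state = SchedulerOutputStateStub()  # reset: every cached field back to defaults
print(state.scheduled_ids, state.allow_output_proc_callback)  # None False
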
@@ -288,22 +291,27 @@ async def step_async(
         cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
+        scheduled_ids = cached_outputs.scheduled_ids
+        allow_output_proc_callback = cached_outputs.allow_output_proc_callback
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
-            (seq_group_metadata_list, scheduler_outputs, scheduled_ids, allow_output_proc_callback) = self.scheduler[
-                virtual_engine].schedule()
+            (seq_group_metadata_list, scheduler_outputs, scheduled_ids,
+             allow_output_proc_callback
+             ) = self.scheduler[virtual_engine].schedule()

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    virtual_engine, seq_group_metadata_list, scheduler_outputs)
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
+                    scheduled_ids, allow_output_proc_callback)

         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
+        assert scheduled_ids is not None

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
@@ -328,6 +336,10 @@ async def step_async(
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
                 last_sampled_token_ids=last_sampled_token_ids)
+
+            if allow_output_proc_callback:
+                execute_model_req.callback_fn = self._process_model_outputs
+
             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
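
Setting callback_fn on the request is the hand-off point of the async path: the engine only attaches its _process_model_outputs routine when allow_output_proc_callback survived the scheduler's checks, so the component executing the model can trigger processing of already-queued results instead of waiting for control to return to the engine loop. A simplified sketch of that pattern; ModelRequestStub, run_model and process_model_outputs are illustrative stand-ins, not vLLM's ExecuteModelRequest or executor API:

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ModelRequestStub:
    seq_group_metadata_list: List[str]
    callback_fn: Optional[Callable[[], None]] = None  # unset => synchronous path


def run_model(req: ModelRequestStub) -> List[str]:
    outputs = [f"token-for-{m}" for m in req.seq_group_metadata_list]
    if req.callback_fn is not None:
        # Async path: previously queued outputs get processed here, without
        # waiting for the engine loop to pick them up.
        req.callback_fn()
    return outputs


def process_model_outputs() -> None:
    print("processing queued outputs")


req = ModelRequestStub(["req-0", "req-1"])
req.callback_fn = process_model_outputs  # only when allow_output_proc_callback
print(run_model(req))
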
@@ -350,17 +362,18 @@ async def step_async(
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

             # Cache results in engine
             self.output_queue.append(
-                    (output, scheduled_ids, scheduler_outputs.ignored_seq_groups))
+                (output, scheduled_ids, scheduler_outputs.ignored_seq_groups))

             if (len(output) > 0) and allow_output_proc_callback:
                 assert len(
                     output
                 ) == 1, "Multi step decoding does not work with output processor callback"  # noqa: E501
-                self._advance_to_next_step(output[0], seq_group_metadata_list,
-                                           scheduler_outputs.scheduled_seq_groups)
+                self._advance_to_next_step(
+                    output[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)

             if not allow_output_proc_callback:
                 self._process_model_outputs(is_async=False)
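
After execution the results land in the engine's output queue either way; when the callback path is on, the engine only advances the sequences in place and asserts single-step output (multi-step decoding was already excluded when scheduling), otherwise it processes outputs synchronously. A condensed sketch of that branch using stand-in callables rather than the engine's real methods:

from collections import deque
from typing import Any, Callable, Deque, List, Tuple

output_queue: Deque[Tuple[List[Any], List[str], List[Any]]] = deque()


def handle_step_output(output: List[Any], scheduled_ids: List[str],
                       ignored: List[Any], allow_output_proc_callback: bool,
                       advance_to_next_step: Callable[[Any], None],
                       process_model_outputs: Callable[[], None]) -> None:
    # Cache results in the engine-level queue first, as the diff does.
    output_queue.append((output, scheduled_ids, ignored))

    if output and allow_output_proc_callback:
        # The callback path assumes single-step output; multi-step decoding
        # was already excluded by the scheduler-side check in this commit.
        assert len(output) == 1
        advance_to_next_step(output[0])
    if not allow_output_proc_callback:
        process_model_outputs()


handle_step_output(["sampler-out"], ["req-0"], [],
                   allow_output_proc_callback=True,
                   advance_to_next_step=lambda o: print("advance with", o),
                   process_model_outputs=lambda: print("sync processing"))
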
@@ -398,12 +411,20 @@ def _has_remaining_steps(
     def _cache_scheduler_outputs_for_multi_step(
             self, virtual_engine: int,
             seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-            scheduler_outputs: SchedulerOutputs) -> None:
+            scheduler_outputs: SchedulerOutputs,
+            scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
+                                               SequenceGroupMetadata]]],
+            allow_output_proc_callback: bool) -> None:
+        v = virtual_engine
         self.cached_scheduler_outputs[
             virtual_engine].seq_group_metadata_list = seq_group_metadata_list
-        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
+        self.cached_scheduler_outputs[v].scheduler_outputs = \
             scheduler_outputs
-        self.cached_scheduler_outputs[virtual_engine].last_output = None
+        self.cached_scheduler_outputs[v].scheduled_ids = \
+            scheduled_ids
+        self.cached_scheduler_outputs[v].allow_output_proc_callback = \
+            allow_output_proc_callback
+        self.cached_scheduler_outputs[v].last_output = None

     def _get_last_sampled_token_ids(
             self, virtual_engine: int) -> Optional[torch.Tensor]:
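
Taken together, the extra arguments to _cache_scheduler_outputs_for_multi_step and the extra reads at the top of step_async form a round trip: a fresh schedule writes scheduled_ids and allow_output_proc_callback into the per-virtual-engine cache, and the next iteration reads them back instead of calling the scheduler again. A hypothetical end-to-end trace of that cycle with plain dicts instead of the real dataclass; for brevity it caches after every fresh schedule, whereas the engine only does so when multi-step lookahead slots remain:

from typing import Any, Callable, Dict, List, Optional, Tuple

cached_scheduler_outputs: Dict[int, Dict[str, Any]] = {0: {}}


def cache_for_multi_step(virtual_engine: int, metadata: List[Any],
                         outputs: Any, scheduled_ids: List[Tuple[Any, Any]],
                         allow_output_proc_callback: bool) -> None:
    v = virtual_engine  # short alias, as in the diff, keeps each assignment on one line
    cached_scheduler_outputs[v] = {
        "seq_group_metadata_list": metadata,
        "scheduler_outputs": outputs,
        "scheduled_ids": scheduled_ids,
        "allow_output_proc_callback": allow_output_proc_callback,
        "last_output": None,
    }


def step(virtual_engine: int, has_remaining_steps: bool,
         fresh_schedule: Optional[Callable[[], Tuple[Any, Any, Any, bool]]]):
    cache = cached_scheduler_outputs[virtual_engine]
    if has_remaining_steps and cache:
        # Intermediate multi-step iteration: reuse the cached schedule,
        # including the callback decision made when it was first built.
        return cache["scheduled_ids"], cache["allow_output_proc_callback"]
    assert fresh_schedule is not None
    metadata, outputs, scheduled_ids, allow_cb = fresh_schedule()
    cache_for_multi_step(virtual_engine, metadata, outputs, scheduled_ids,
                         allow_cb)
    return scheduled_ids, allow_cb


print(step(0, False, lambda: (["meta"], "outs", [("sg", "meta")], False)))
print(step(0, True, None))  # cached schedule reused; the scheduler is not called
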
