From 0042147f883b082aab28237748151657f9e198d0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 14 Sep 2024 16:58:31 -0700 Subject: [PATCH] [TPU] Implement multi-step scheduling (#8489) --- vllm/config.py | 2 +- vllm/executor/ray_tpu_executor.py | 8 +- vllm/executor/tpu_executor.py | 16 +- vllm/worker/multi_step_tpu_worker.py | 105 +++++++++++++ vllm/worker/tpu_model_runner.py | 224 +++++++++++++++++++-------- 5 files changed, 279 insertions(+), 76 deletions(-) create mode 100644 vllm/worker/multi_step_tpu_worker.py diff --git a/vllm/config.py b/vllm/config.py index 9684cea813134..89cffc8b306b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -379,7 +379,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - if self.enforce_eager: + if device_config.device_type == "cuda" and self.enforce_eager: logger.warning( "To see benefits of async output processing, enable CUDA " "graph. Since, enforce-eager is enabled, async output " diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 8c8b5f741488b..732b69d6e5954 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -68,8 +68,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", ) assert self.speculative_config is None - worker_module_name = "vllm.worker.tpu_worker" - worker_class_name = "TPUWorker" + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_tpu_worker" + worker_class_name = "MultiStepTPUWorker" + else: + worker_module_name = "vllm.worker.tpu_worker" + worker_class_name = "TPUWorker" # GKE does not fetch environment information from metadata server # and instead sets these from within the Ray process. Therefore we diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 0af8ba41e24d5..972649dedf33e 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -62,11 +62,17 @@ def _create_worker( rank: int = 0, distributed_init_method: Optional[str] = None, ): - from vllm.worker.tpu_worker import TPUWorker - - worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank, - distributed_init_method)) - return worker + if self.scheduler_config.is_multi_step: + from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker + worker = MultiStepTPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker + else: + from vllm.worker.tpu_worker import TPUWorker + + worker = TPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker def initialize_cache( self, diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py new file mode 100644 index 0000000000000..e654f7172b266 --- /dev/null +++ b/vllm/worker/multi_step_tpu_worker.py @@ -0,0 +1,105 @@ +import dataclasses +from typing import Dict, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict +from vllm.sequence import ExecuteModelRequest +from vllm.worker.tpu_model_runner import ModelInputForTPU +from vllm.worker.tpu_worker import TPUWorker +from vllm.worker.worker_base import WorkerInput + + +class MultiStepTPUWorker(TPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cached_model_input: Optional[ModelInputForTPU] = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: + 
assert self.is_driver_worker + assert execute_model_req.virtual_engine == 0 + + is_first_multi_step = execute_model_req.is_first_multi_step + is_last_step = execute_model_req.is_last_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + worker_input = dataclasses.replace( + worker_input, + num_steps=execute_model_req.num_lookahead_slots + 1) + model_input: ModelInputForTPU = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( + model_input, + async_callback=execute_model_req.async_callback) + else: + assert self.cached_model_input is not None + model_input = self.cached_model_input + worker_input = WorkerInput() + model_input = dataclasses.replace( + model_input, + is_first_multi_step=is_first_multi_step, + is_last_step=is_last_step) + + if self.do_metadata_broadcast: + if is_first_multi_step: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + broadcast_data = { + "is_first_multi_step": is_first_multi_step, + "is_last_step": is_last_step, + } + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str, + torch.Tensor]]]: + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + broadcast_tensor_dict({}, src=0) + return None + + model_input, worker_input, _ = self._get_driver_input_and_broadcast( + execute_model_req) + if model_input.is_first_multi_step: + self.cached_model_input = model_input + return model_input, worker_input, {} + else: + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + if len(broadcast_data) == 2: + assert self.cached_model_input is not None + self.cached_model_input = dataclasses.replace( + self.cached_model_input, + is_first_multi_step=broadcast_data["is_first_multi_step"], + is_last_step=broadcast_data["is_last_step"]) + empty_worker_input = WorkerInput() + return self.cached_model_input, empty_worker_input, {} + + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. 
+ make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + self.cached_model_input = model_input + return model_input, worker_input, {} diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index db306bc743d3a..575769ca1aa4a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -51,6 +51,8 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int best_of: List[int] seq_groups: List[List[int]] + is_first_multi_step: bool = True + is_last_step: bool = True virtual_engine: int = 0 async_callback: Optional[Callable] = None @@ -65,6 +67,8 @@ def as_broadcastable_tensor_dict( "num_samples": self.num_samples, "best_of": self.best_of, "seq_groups": self.seq_groups, + "is_first_multi_step": self.is_first_multi_step, + "is_last_step": self.is_last_step, "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -118,6 +122,7 @@ def __init__( self.block_size, False, ) + self.cached_step_outputs: List[torch.Tensor] = [] def load_model(self) -> None: self.device = self.device_config.device @@ -518,97 +523,159 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None - if num_steps > 1: - raise ValueError( - "TPUModelRunner does not support multi-step execution.") - - def _execute_model(*args): - """Move input args from CPU to device and execute the model.""" - - new_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - arg = arg.to(self.device) - elif isinstance(arg, AttentionMetadata): - arg.slot_mapping = arg.slot_mapping.to(self.device) - if getattr(arg, "block_tables", None) is not None: - arg.block_tables = arg.block_tables.to(self.device) - if getattr(arg, "context_lens", None) is not None: - arg.context_lens = arg.context_lens.to(self.device) - new_args.append(arg) - return self.model(*new_args, is_prompt=is_prompt) - - num_prefills = model_input.attn_metadata.num_prefills - is_prompt = num_prefills > 0 + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 if is_prompt: + assert num_steps == 1 # NOTE(woosuk): Since the FlashAttention kernel does not support # ragged inputs, we split the prompts into different batches and # process them separately. This is a temporary hack that should be # optimized by using SplashAttention. - next_token_ids = [] orig_slot_mapping = model_input.attn_metadata.slot_mapping batch_size = model_input.input_lens.shape[0] start_idx = 0 + next_token_ids = [] for i in range(batch_size): # Get the actual prefill_len. 
prefill_len = model_input.input_lens[i:i + 1].item() prefill_len = _get_padded_prefill_len(prefill_len) end_idx = start_idx + prefill_len - model_input.attn_metadata.slot_mapping = orig_slot_mapping[ - None, start_idx:end_idx] - model_input.attn_metadata.num_prefills = 1 - output_token_ids = _execute_model( - model_input.token_ids[None, start_idx:end_idx], - model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], model_input.p[i:i + 1], - model_input.num_samples, kv_caches) - if i == 0 and model_input.async_callback is not None: - model_input.async_callback() - # Retrieve the outputs to CPU. - next_token_ids += output_token_ids.cpu().tolist() + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=True) + next_token_ids.append(output_token_ids[0]) start_idx = end_idx - else: - # Execute the model. - output_token_ids = _execute_model( - model_input.token_ids, model_input.position_ids, - model_input.attn_metadata, model_input.input_lens, - model_input.t, model_input.p, model_input.num_samples, - kv_caches) + if model_input.async_callback is not None: model_input.async_callback() # Retrieve the outputs to CPU. - next_token_ids = output_token_ids.cpu().tolist() - - # NOTE(woosuk): Minimal code to construct the sampler outputs. - # The TPU backend does not reuse the sampler, since the TPU backend - # does not support the advanced sampling parameters such as logprobs. - zero_logprob = Logprob(0.0) - batch_idx = 0 - sampler_outputs = [] - for seq_group in model_input.seq_groups: - seq_ids = seq_group - seq_outputs = [] - if is_prompt: + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. 
+ zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group assert len(seq_ids) == 1 seq_id = seq_ids[0] - for i in range(model_input.best_of[batch_idx]): - next_token_id = next_token_ids[batch_idx][i] + seq_outputs = [] + for j in range(model_input.best_of[i]): + next_token_id = next_token_ids[i][j] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) - batch_idx += 1 - else: - for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - batch_idx += 1 - sampler_outputs.append( - CompletionSequenceGroupOutput(seq_outputs, None)) - return [SamplerOutput(sampler_outputs)] + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=False) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. + token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. 
+ next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): @@ -756,3 +823,24 @@ def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs)
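
On TPU, MultiStepTPUWorker only broadcasts the full worker and model inputs on the first step of a multi-step batch; every later step broadcasts just the two booleans is_first_multi_step and is_last_step, and non-driver workers splice those flags into their cached ModelInputForTPU (the len(broadcast_data) == 2 branch above). A minimal standalone sketch of that control flow on the non-driver side, using a made-up FakeModelInput dataclass and plain dicts in place of the real vLLM types and broadcast_tensor_dict:

import dataclasses
from typing import Optional


@dataclasses.dataclass
class FakeModelInput:
    # Stand-in for ModelInputForTPU; only the fields needed for the sketch.
    token_ids: list
    is_first_multi_step: bool = True
    is_last_step: bool = True


cached_model_input: Optional[FakeModelInput] = None


def handle_broadcast(broadcast_data: dict) -> Optional[FakeModelInput]:
    """Non-driver side: full payload on the first step, flags-only afterwards."""
    global cached_model_input
    if not broadcast_data:
        # An empty broadcast means there is no work this iteration.
        return None
    if len(broadcast_data) == 2:
        # Later steps: only the two flags arrive; reuse the cached input.
        assert cached_model_input is not None
        cached_model_input = dataclasses.replace(
            cached_model_input,
            is_first_multi_step=broadcast_data["is_first_multi_step"],
            is_last_step=broadcast_data["is_last_step"])
        return cached_model_input
    # First step: rebuild and cache the full model input.
    cached_model_input = FakeModelInput(**broadcast_data)
    return cached_model_input


# First step carries the full input; later steps carry only the two flags.
step0 = handle_broadcast({"token_ids": [1, 2, 3],
                          "is_first_multi_step": True,
                          "is_last_step": False})
step1 = handle_broadcast({"is_first_multi_step": False, "is_last_step": True})
assert step1 is not None and step1.token_ids == [1, 2, 3] and step1.is_last_step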
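
The prefill path runs each prompt as its own batch of size one because the FlashAttention kernel used here does not accept ragged inputs (see the NOTE in execute_model), slicing the flattened token buffer with a running start_idx/end_idx over per-prompt padded lengths. A sketch of that bookkeeping, assuming for illustration that padding rounds each length up to the next power of two with a floor of 16; padded_prefill_len below is a stand-in, not the real _get_padded_prefill_len:

from typing import List, Tuple


def padded_prefill_len(n: int) -> int:
    # Stand-in padding rule: round up to the next power of two, minimum 16.
    if n <= 16:
        return 16
    return 1 << (n - 1).bit_length()


def prompt_slices(input_lens: List[int]) -> List[Tuple[int, int]]:
    """Return (start_idx, end_idx) slices into the flattened, padded token buffer."""
    slices = []
    start_idx = 0
    for actual_len in input_lens:
        end_idx = start_idx + padded_prefill_len(actual_len)
        slices.append((start_idx, end_idx))
        start_idx = end_idx
    return slices


# Prompts of 10 and 20 tokens occupy padded chunks of 16 and 32 tokens.
assert prompt_slices([10, 20]) == [(0, 16), (16, 48)]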
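
In the decode path, the multi-step loop stays on device between steps: after incrementing the positions and context lengths, it re-derives the KV-cache slot mapping from the block tables while preserving the pad sentinel for padded sequences. A minimal sketch of that arithmetic; block_size, block_tables, position_ids, slot_mapping, and PAD_SLOT_ID mirror the names in the patch, but the tensor values are made up for illustration:

import torch

PAD_SLOT_ID = -1   # sentinel for padded sequences, mirroring _PAD_SLOT_ID
block_size = 16

# Hypothetical state for a batch of two decoding sequences (shapes [batch, 1]).
block_tables = torch.tensor([[0, 3], [7, 2]])        # physical block ids per sequence
position_ids = torch.tensor([[15], [4]])             # positions of the tokens just sampled
slot_mapping = torch.tensor([[15], [PAD_SLOT_ID]])   # the second sequence is padding

# Advance to the next decode step.
position_ids = position_ids + 1

# Physical block that holds the next position, and the offset inside it.
block_number = block_tables.gather(1, position_ids.long() // block_size)
block_offset = position_ids % block_size

# Recompute the slots, but keep padded sequences pointing at the pad slot.
is_padding = slot_mapping == PAD_SLOT_ID
slot_mapping = (block_number * block_size + block_offset).long()
slot_mapping = torch.where(is_padding, PAD_SLOT_ID, slot_mapping)

print(position_ids.tolist())   # [[16], [5]]
print(slot_mapping.tolist())   # [[48], [-1]]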