[V1][Core] Add worker_base for v1 worker (vllm-project#12816)
Signed-off-by: Aoyu <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Aoyu <[email protected]>
Co-authored-by: youkaichao <[email protected]>
3 people authored and I746365 committed Feb 15, 2025
1 parent 5a27c43 commit 08c1cd3
Showing 4 changed files with 153 additions and 52 deletions.
43 changes: 43 additions & 0 deletions vllm/utils.py
@@ -2220,3 +2220,46 @@ def import_pynvml():
"""
import vllm.third_party.pynvml as pynvml
return pynvml


def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
    """
    A replacement for `abc.ABC`.
    When we use `abc.ABC`, subclasses will fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class, and log a warning if the method is not implemented
    in the subclass.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        unimplemented_methods = []
        for attr_name in dir(self):
            # skip private/internal attributes
            if attr_name.startswith('_'):
                continue

            try:
                attr = getattr(self, attr_name)
                # skip non-callables; unwrap the bound method to its
                # underlying function
                if not callable(attr):
                    continue
                attr_func = attr.__func__
            except AttributeError:
                continue
            src = inspect.getsource(attr_func)
            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)
        if unimplemented_methods:
            method_names = ','.join(unimplemented_methods)
            msg = f"Methods {method_names} not implemented in {self}"
            logger.warning(msg)

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    type.__setattr__(cls, '__init__', wrapped_init)
    return cls
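
As an aside for readers of this diff: here is a minimal sketch (not part of the commit, with hypothetical class and method names) of how the decorator behaves once applied:

from vllm.utils import warn_for_unimplemented_methods

@warn_for_unimplemented_methods
class Base:

    def do_work(self) -> None:
        # The base class marks the method as unimplemented instead of
        # using @abstractmethod.
        raise NotImplementedError

class Incomplete(Base):
    pass

# Unlike abc.ABC, this instantiation succeeds; the wrapped __init__ only
# logs "Methods do_work not implemented in <...Incomplete object ...>".
Incomplete()

This trades abc.ABC's hard failure at instantiation time for a log line, which lets a partially implemented worker run as long as the missing methods are never called.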
28 changes: 9 additions & 19 deletions vllm/v1/worker/gpu_worker.py
@@ -21,14 +21,15 @@
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput


-class Worker:
+class Worker(WorkerBase):

    def __init__(
        self,
@@ -39,23 +40,11 @@ def __init__(
        is_driver_worker: bool = False,
    ):

-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
+        super().__init__(vllm_config=vllm_config,
+                         local_rank=local_rank,
+                         rank=rank,
+                         distributed_init_method=distributed_init_method,
+                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
@@ -126,7 +115,8 @@ def init_device(self):
        set_random_seed(self.model_config.seed)

        # Construct the model runner
-        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
+        self.model_runner: GPUModelRunner = GPUModelRunner(
+            self.vllm_config, self.device)

    def load_model(self) -> None:
        if self.vllm_config.model_config.enable_sleep_mode:
63 changes: 63 additions & 0 deletions vllm/v1/worker/worker_base.py
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

logger = init_logger(__name__)


class WorkerBase(WorkerBaseV0):
    """
    Abstract class for the v1 worker; it mainly defines v1-specific
    methods. Methods shared by v0 and v1 are defined in the v0 WorkerBase.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Initialize common worker components.

        Args:
            vllm_config: Complete vLLM configuration
            local_rank: Local device index
            rank: Global rank in distributed setup
            distributed_init_method: Distributed initialization method
            is_driver_worker: Whether this worker handles driver
                responsibilities
        """
        # Configuration storage
        super().__init__(vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        # Device and model state
        self.device: Optional[torch.device] = None
        self.model_runner: Optional[nn.Module] = None

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Get specifications for KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Prepare model for execution through compilation/warmup."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return
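
To make the contract concrete, here is a minimal sketch (not part of the commit) of a hypothetical v1 worker built on this base class; the class name and method bodies are placeholders:

from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.worker.worker_base import WorkerBase

class FakeAcceleratorWorker(WorkerBase):

    def get_kv_cache_spec(self) -> KVCacheSpec:
        # Placeholder: describe the KV cache layout this device supports.
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        # Placeholder: run warmup passes (or compilation) here.
        pass

Because the v0 WorkerBase is decorated with warn_for_unimplemented_methods, constructing such a worker logs a warning for every method whose body still raises NotImplementedError (here, get_kv_cache_spec) rather than refusing to instantiate.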
71 changes: 38 additions & 33 deletions vllm/worker/worker_base.py
@@ -3,7 +3,7 @@
import dataclasses
import os
import time
-from abc import ABC, abstractmethod
+from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import cloudpickle
@@ -19,15 +19,17 @@
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
                        resolve_obj_by_qualname, run_method,
-                        update_environment_variables)
+                        update_environment_variables,
+                        warn_for_unimplemented_methods)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)

logger = init_logger(__name__)


-class WorkerBase(ABC):
+@warn_for_unimplemented_methods
+class WorkerBase:
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.
@@ -53,35 +55,31 @@ def __init__(
        from vllm.platforms import current_platform
        self.current_platform = current_platform

-    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError

-    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.
@@ -94,40 +92,43 @@ def start_worker_execution_loop(self) -> None:
            if output is None:
                return None

-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
-
-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
-        raise NotImplementedError
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.
+
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
+        raise NotImplementedError

-    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

-    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

-    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

-    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

-    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()


class DelegateWorkerBase(WorkerBase):
    """
@@ -156,6 +157,10 @@ def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

+    def load_model(self) -> None:
+        """Load model onto target device."""
+        self.worker.load_model()

    def get_model(self) -> nn.Module:
        return self.worker.get_model()
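
The load_model forwarder added in this hunk follows the file's existing pattern: DelegateWorkerBase forwards each WorkerBase method explicitly to the wrapped self.worker, so every method added to the interface needs a matching forwarder here. A generic sketch of that explicit-delegation pattern (hypothetical code, not from the commit):

class Delegate:

    def __init__(self, worker) -> None:
        # Hold the concrete worker and forward its interface explicitly.
        self.worker = worker

    def load_model(self) -> None:
        self.worker.load_model()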
