[V1][Core] Add worker_base for v1 worker #12816

Merged · 16 commits · Feb 13, 2025

Changes from 2 commits
28 changes: 9 additions & 19 deletions vllm/v1/worker/gpu_worker.py
@@ -22,14 +22,15 @@
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput


class Worker:
class Worker(WorkerBase):

def __init__(
self,
@@ -40,23 +41,11 @@ def __init__(
is_driver_worker: bool = False,
):

# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config

self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
@@ -127,7 +116,8 @@ def init_device(self):
set_random_seed(self.model_config.seed)

# Construct the model runner
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)

def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode:
66 changes: 66 additions & 0 deletions vllm/v1/worker/worker_base.py
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from typing import Optional

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

logger = init_logger(__name__)


class WorkerBase(WorkerBaseV0):
"""
Abstract class for the v1 worker; it mainly defines v1-specific methods.
Methods shared by v0 and v1 are defined in the v0 WorkerBase.
"""

def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
"""
Initialize common worker components.

Args:
vllm_config: Complete vLLM configuration
local_rank: Local device index
rank: Global rank in distributed setup
distributed_init_method: Distributed initialization method
is_driver_worker: Whether this worker handles driver
responsibilities
"""
# Configuration storage
super().__init__(vllm_config=vllm_config)

self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker

# Device and model state
self.device: Optional[torch.device] = None
self.model_runner: Optional[nn.Module] = None

@abstractmethod
def get_kv_cache_spec(self) -> KVCacheSpec:
"""Get specifications for KV cache implementation."""
raise NotImplementedError

@abstractmethod
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
raise NotImplementedError

def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
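For orientation, the structure this new file sets up is a three-level hierarchy: the v0 WorkerBase stores configuration and methods shared across both versions, this v1 WorkerBase adds the v1-only hooks, and concrete workers such as the GPU worker fill them in. Below is a minimal, self-contained toy sketch of that layering; the class names, signatures, and bodies are simplified placeholders, not the real vLLM API.

```python
# Toy sketch of the layering introduced by this PR (simplified names and
# signatures; not the real vLLM classes).
from abc import abstractmethod


class WorkerBaseV0:
    """Shared by v0 and v1: stores the config and common defaults."""

    def __init__(self, vllm_config: dict):
        self.vllm_config = vllm_config

    def load_model(self) -> None:
        raise NotImplementedError


class WorkerBaseV1(WorkerBaseV0):
    """v1-only hooks, mirroring vllm/v1/worker/worker_base.py above."""

    def __init__(self, vllm_config: dict, local_rank: int, rank: int):
        super().__init__(vllm_config=vllm_config)
        self.local_rank = local_rank
        self.rank = rank

    @abstractmethod
    def get_kv_cache_spec(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def compile_or_warm_up_model(self) -> None:
        raise NotImplementedError


class GpuWorker(WorkerBaseV1):
    """Concrete worker: only fills in what the base classes leave open."""

    def load_model(self) -> None:
        print("loading model")

    def get_kv_cache_spec(self) -> dict:
        return {}

    def compile_or_warm_up_model(self) -> None:
        print("warming up")


worker = GpuWorker(vllm_config={}, local_rank=0, rank=0)
worker.load_model()
```

Because the toy WorkerBaseV0 here is not an ABC, the @abstractmethod markers do not block instantiation by themselves, which is exactly the trade-off debated in the review thread below.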
124 changes: 94 additions & 30 deletions vllm/worker/worker_base.py
@@ -4,7 +4,9 @@
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from functools import wraps
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)

import cloudpickle
import torch
@@ -26,7 +28,68 @@

logger = init_logger(__name__)

T = TypeVar('T')


def check_implementation():
Member:

I feel this part is too complicated to understand. Can we just raise NotImplementedError?

If we don't get these errors, it means this function isn't being called right now, and that's fine.

If we do get these errors, the developers will definitely see them and implement the method if necessary.

Contributor Author:

Since both @youkaichao and @DarkLight1337 are members of this project, I hope you can agree on a certain level of direction and positioning so that we developers can continue to contribute.

Member:

We still raise NotImplementedError when the method is used; this code just logs additional warnings if some methods aren't implemented.

Member:

> this code is just to log additional warnings if some methods aren't implemented

Why do we need this?

DarkLight1337 (Member), Feb 11, 2025:

It becomes easier to keep track of feature completeness, since we know which features are still missing without waiting for someone to try using a feature and then open a "bug report".

Contributor Author:

I think this helps remind developers that some base-class methods aren't implemented in the inheriting subclasses.

Member:

I feel it's overkill and complicates the codebase a lot.

"""
A decorator that checks if all abstract methods from the base class
are implemented in the subclass and gives warnings for unimplemented
methods.
"""

def decorator(cls: Type[T]) -> Type[T]:

original_init = cls.__init__

def warn_unimplemented_methods(self: object):
unimplemented_methods = []
for attr_name in dir(self):
# skip private and dunder attributes
if attr_name.startswith('_'):
continue
base_method = getattr(self, attr_name)
# skip methods excluded from the check via @avoid_check
if getattr(base_method, '_avoid_check', False):
continue
# get the func of callable method
if callable(base_method):
base_method_name = base_method.__func__
else:
continue
class_method_name = getattr(cls, attr_name, False)
# bypass method defined in sub class
if not class_method_name:
continue
if class_method_name == base_method_name:
unimplemented_methods.append(attr_name)
if unimplemented_methods:
method_names = ','.join(unimplemented_methods)
msg = (f"Methods {method_names} not implemented in {self}")
logger.warning(msg)

@wraps(original_init)
def wrapped_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
warn_unimplemented_methods(self)

cls.__init__ = wrapped_init
return cls

return decorator


def avoid_check(func: Callable[..., T]) -> Callable[..., T]:

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> T:
return func(*args, **kwargs)

wrapper._avoid_check = True # type: ignore
return wrapper


@check_implementation()
class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
@@ -60,28 +123,26 @@ def init_device(self) -> None:
"""
raise NotImplementedError

@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.

The implementation may run profiling or other heuristics to determine
the size of caches.

Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError

@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError

def get_model(self) -> nn.Module:
raise NotImplementedError
Member:

Since we have raise NotImplementedError, do we still need the @abstractmethod?

DarkLight1337 (Member), Feb 10, 2025:

Yes, doing so prevents subclasses from being instantiated if they haven't implemented the abstract method, instead of only erroring out when the method is called.

Member:

> Yes, doing so prevents subclasses from being instantiated if they haven't implemented the abstract method, instead of only erroring out when the method is called.

I think that's too restrictive and hurts development. Erroring out when the method is called seems enough; we can add the implementations step by step.

Contributor Author:

I agree with removing the unnecessary @abstractmethod decorators and only raising NotImplementedError. What do you think?

Member:

To avoid forgetting to implement them at the end, can we add a test that prints out which abstract methods remain to be implemented for each worker subclass?

Contributor Author:

This confuses me. What's the difference between throwing an error when the method is called (@abstractmethod) and failing a test that checks whether the abstract methods are implemented?

Member:

We can make the test print out warnings instead of failing outright.

Member:

The difference is that users won't get unnecessary warnings while developers can remain aware of this.

Contributor Author:

Hi @DarkLight1337 and @youkaichao, I propose a decorator that emits warnings for methods not implemented in the subclass. It will show something like this:

WARNING 02-10 12:57:48 worker_base.py:67] Methods add_lora,determine_num_available_blocks,get_cache_block_size_bytes,list_loras,pin_lora,remove_lora not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7a025b3297b0>

I've still kept one @abstractmethod there to pass the pre-commit inspection temporarily. If everyone agrees with this direction, I will make further changes to remove @abstractmethod in a way that meets the needs of both users and developers.

Member:

Sounds good to me. You can keep the abstractmethod for get_model as it's the most straightforward one to implement.
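To illustrate the proposal above, here is a hypothetical usage sketch (not part of the PR): a toy base class decorated with check_implementation and a subclass that overrides nothing, which triggers the consolidated warning at construction time. It assumes check_implementation and avoid_check are importable from vllm.worker.worker_base as added in this diff; the ToyBase/ToyWorker names are invented for illustration.

```python
# Hypothetical usage sketch for the decorator added in this diff.
# Assumes a vLLM checkout containing this change so the import resolves.
from vllm.worker.worker_base import avoid_check, check_implementation


@check_implementation()
class ToyBase:

    def __init__(self) -> None:
        pass

    def add_lora(self) -> bool:
        raise NotImplementedError

    @avoid_check
    def start_loop(self) -> None:
        # Shared implementation; @avoid_check keeps it out of the warning.
        return None


class ToyWorker(ToyBase):
    # add_lora is intentionally left unimplemented here.
    pass


# Constructing the subclass runs the wrapped __init__ and logs roughly:
#   WARNING ... Methods add_lora not implemented in <...ToyWorker object ...>
ToyWorker()
```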


def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError

def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError

@avoid_check
Member:

I feel this is too complicated; we'd need to decorate every method that has a common implementation in the base class with @avoid_check.

Member:

We can check the __isabstractmethod__ attribute to select only the methods that are decorated with @abstractmethod.
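For reference, the __isabstractmethod__-based filter suggested here would look roughly like the following hypothetical sketch; it was not adopted in this PR, since most @abstractmethod decorators were later removed.

```python
# Hypothetical sketch of filtering on __isabstractmethod__, the attribute
# that @abstractmethod sets on the decorated function.
from abc import abstractmethod


class ToyBase:

    @abstractmethod
    def add_lora(self) -> bool:
        raise NotImplementedError

    def check_health(self) -> None:
        return


def abstract_method_names(cls: type) -> list:
    # Only methods decorated with @abstractmethod carry this attribute.
    return [
        name for name, attr in vars(cls).items()
        if getattr(attr, "__isabstractmethod__", False)
    ]


print(abstract_method_names(ToyBase))  # ['add_lora']
```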

Contributor Author:

We probably need to go back to the original point of the discussion, which is making development easier, right? Thanks @youkaichao for the reminder. I initially inherited from the v0 WorkerBase (vllm.worker.worker_base.WorkerBase). However, that base has some @abstractmethod methods that aren't needed in v1 right now, such as add_lora and remove_lora.

It's easy to just remove @abstractmethod, but that's probably not development-friendly, because developers always want a good reminder of which methods aren't implemented yet. With the warning below, a developer sees at a glance that some methods are still unimplemented in the v1 GPU worker:

WARNING 02-10 12:57:48 worker_base.py:67] Methods add_lora,determine_num_available_blocks,get_cache_block_size_bytes,list_loras,pin_lora,remove_lora not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7a025b3297b0>

As a developer, you can choose whether to implement them, and you don't need to go through trial and error to find out what's missing. So the final experience is that we keep only the key abstract methods and tell developers about the rest through reminders.

BTW, @DarkLight1337, we cannot use __isabstractmethod__ since we removed most of the @abstractmethod decorators. And only start_worker_execution_loop needs @avoid_check, so this shouldn't be a huge amount of work, right? @youkaichao

Member:

How about keeping @abstractmethod but effectively disabling it by removing the ABC base class?

Contributor Author:

> How about keeping @abstractmethod but effectively disabling it by removing the ABC base class?

@DarkLight1337 this makes life easier and won't throw errors if we don't implement an @abstractmethod from the base class. But then there is no reminder either. If this is an approach you agree with, isn't @youkaichao's idea a more straightforward approach that doesn't even require @abstractmethod?

DarkLight1337 (Member), Feb 11, 2025:

I think currently check_implementation replaces the functionality of ABC (which throws errors instead of warnings if there are unimplemented methods). Basically, we use abstractmethod (soft check at initialization) to remind ourselves which methods are still missing, and use raise NotImplementedError (hard check at call site) to prevent real usage. Both can exist together and complement each other.

Contributor Author:

> I think currently check_implementation replaces the functionality of ABC (which throws errors instead of warnings if there are unimplemented methods). Basically, we use abstractmethod (soft check at initialization) + raise NotImplementedError (hard check at call site) to remind ourselves which methods are still missing.

I see. This makes sense to me.
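To make that distinction concrete, here is a small hypothetical sketch (not part of the PR; the class names are invented) contrasting the two checks: an ABC with @abstractmethod fails at instantiation, while a plain base that only raises NotImplementedError fails only when the missing method is actually called.

```python
# Hypothetical sketch contrasting the hard check at initialization
# (ABC + @abstractmethod) with the hard check at the call site
# (raise NotImplementedError).
from abc import ABC, abstractmethod


class StrictBase(ABC):

    @abstractmethod
    def get_model(self):
        raise NotImplementedError


class LazyBase:  # no ABC: abstract markers are effectively disabled

    def get_model(self):
        raise NotImplementedError


class StrictWorker(StrictBase):
    pass


class LazyWorker(LazyBase):
    pass


# Fails immediately, even if get_model is never used.
try:
    StrictWorker()
except TypeError as exc:
    print(f"abstractmethod check: {exc}")

# Construction succeeds; the error surfaces only on the actual call.
worker = LazyWorker()
try:
    worker.get_model()
except NotImplementedError:
    print("NotImplementedError raised only when the method is called")
```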

Contributor Author:

Hey guys, I really love the open-source community's discussion atmosphere, but I also hope this PR doesn't get dropped halfway. Do the members have any mechanism for cases like this, like voting or something? Whatever the change ends up being, I need one clear, unified suggestion. Thank you! @youkaichao @DarkLight1337

Member:

Let's gather some more thoughts before proceeding. V1 is high priority so it will get resolved soon!

def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.

@@ -94,40 +155,43 @@ def start_worker_execution_loop(self) -> None:
if output is None:
return None

@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.

@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
The implementation may run profiling or other heuristics to determine
the size of caches.

Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError

@abstractmethod
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError

@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError

@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def list_loras(self) -> Set[int]:
raise NotImplementedError

@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()


class DelegateWorkerBase(WorkerBase):
"""