-
-
Notifications
You must be signed in to change notification settings - Fork 6.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1][Core] Add worker_base for v1 worker #12816
Changes from 2 commits
cf26f2c
b400e14
c43e6bc
8e7906d
052cd6b
4f86570
86d0705
cb8a099
889c72a
a68c0fe
d3c7075
47ceccf
c5b222c
7c54cd4
84b2dab
793fe42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from abc import abstractmethod | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from vllm.config import VllmConfig | ||
from vllm.logger import init_logger | ||
from vllm.v1.kv_cache_interface import KVCacheSpec | ||
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class WorkerBase(WorkerBaseV0):
    """
    Abstract base class for v1 workers; it mainly declares v1-only methods.
    Methods shared by v0 and v1 belong in the v0 WorkerBase instead.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Set up the state common to every v1 worker.

        Args:
            vllm_config: Complete vLLM configuration; forwarded to the v0
                base-class initializer.
            local_rank: Local device index.
            rank: Global rank in the distributed setup.
            distributed_init_method: Distributed initialization method.
            is_driver_worker: Whether this worker handles driver
                responsibilities.
        """
        # The v0 base class is responsible for storing the configuration.
        super().__init__(vllm_config=vllm_config)

        # Device and model state: both start out unset (None).
        self.device: Optional[torch.device] = None
        self.model_runner: Optional[nn.Module] = None

        # Identity of this worker within the distributed job.
        self.is_driver_worker = is_driver_worker
        self.distributed_init_method = distributed_init_method
        self.rank = rank
        self.local_rank = local_rank

    @abstractmethod
    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the specification of the KV cache implementation."""
        raise NotImplementedError

    @abstractmethod
    def compile_or_warm_up_model(self) -> None:
        """Get the model ready for execution via compilation/warmup."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check: a no-op here; override for device-specific
        checks."""
        return None
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,9 @@ | |
import os | ||
import time | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union | ||
from functools import wraps | ||
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type, | ||
TypeVar, Union) | ||
|
||
import cloudpickle | ||
import torch | ||
|
@@ -26,7 +28,68 @@ | |
|
||
logger = init_logger(__name__) | ||
|
||
T = TypeVar('T')


def check_implementation():
    """
    Build a class decorator that warns about base-class methods which the
    instantiated (sub)class has not overridden.

    The returned decorator replaces ``cls.__init__`` with a wrapper that runs
    the original initializer and then compares every public plain method
    visible on ``cls`` with the attribute of the same name on the instance's
    concrete class. Every method that still resolves to the very function
    seen on ``cls`` is reported in one ``logger.warning`` call. Nothing is
    raised for a missing override.

    A method is never reported when:
      * its name starts with an underscore,
      * it carries a truthy ``_avoid_check`` attribute (see ``avoid_check``),
      * the instance has an attribute of the same name in its ``__dict__``
        (an instance-level override counts as an implementation).

    The inspection is purely static: attributes are looked up on the classes
    without triggering descriptors, so properties are not evaluated on the
    freshly built instance, and non-method callables (staticmethods, callable
    instance attributes, ...) are skipped instead of crashing on a missing
    ``__func__``.

    Returns:
        A decorator taking a class and returning that same class with its
        ``__init__`` wrapped.
    """
    # Function-scope import keeps this helper self-contained.
    import inspect

    def decorator(cls: Type[T]) -> Type[T]:

        original_init = cls.__init__

        def warn_unimplemented_methods(self: object) -> None:
            instance_cls = type(self)
            # Objects without a __dict__ (e.g. __slots__ classes) simply have
            # no instance-level overrides to consider.
            instance_attrs = getattr(self, '__dict__', {})
            unimplemented_methods = []

            # Only names visible on the decorated class matter: a method that
            # exists solely on a subclass cannot be "unimplemented".
            for attr_name in dir(cls):
                # bypass inner method
                if attr_name.startswith('_'):
                    continue
                # bypass names overridden directly on the instance
                if attr_name in instance_attrs:
                    continue
                # Static lookup: never runs property getters or other
                # descriptors.
                base_method = inspect.getattr_static(cls, attr_name, None)
                # bypass everything that is not a plain method (properties,
                # staticmethods, classmethods, data attributes)
                if not inspect.isfunction(base_method):
                    continue
                # bypass method explicitly excluded from the check
                if getattr(base_method, '_avoid_check', False):
                    continue
                # Same function object on the concrete class means nobody
                # between ``cls`` and the concrete class overrode the method.
                resolved = inspect.getattr_static(instance_cls, attr_name,
                                                  None)
                if resolved is base_method:
                    unimplemented_methods.append(attr_name)

            if unimplemented_methods:
                method_names = ','.join(unimplemented_methods)
                msg = (f"Methods {method_names} not implemented in {self}")
                logger.warning(msg)

        @wraps(original_init)
        def wrapped_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            warn_unimplemented_methods(self)

        cls.__init__ = wrapped_init
        return cls

    return decorator
|
||
|
||
def avoid_check(func: Callable[..., T]) -> Callable[..., T]:
    """Wrap ``func`` in a pass-through and tag the wrapper with
    ``_avoid_check``.

    The wrapper forwards all positional and keyword arguments to ``func``
    unchanged and returns its result. The ``_avoid_check = True`` attribute
    set on the returned wrapper is the marker that ``check_implementation``
    in this file looks for in order to skip a method.
    """

    @wraps(func)
    def passthrough(*call_args: Any, **call_kwargs: Any) -> T:
        result = func(*call_args, **call_kwargs)
        return result

    # Marker read by the implementation check.
    setattr(passthrough, '_avoid_check', True)
    return passthrough
|
||
|
||
@check_implementation() | ||
class WorkerBase(ABC): | ||
"""Worker interface that allows vLLM to cleanly separate implementations for | ||
different hardware. Also abstracts control plane communication, e.g., to | ||
|
@@ -60,28 +123,26 @@ def init_device(self) -> None: | |
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def determine_num_available_blocks(self) -> Tuple[int, int]: | ||
"""Determine the number of available blocks for the GPU KV cache and | ||
swappable CPU KV cache. | ||
|
||
The implementation may run profiling or other heuristics to determine | ||
the size of caches. | ||
|
||
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks | ||
are blocks that are "active" on the device and can be appended to. | ||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be | ||
appended to. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def initialize_cache(self, num_gpu_blocks: int, | ||
num_cpu_blocks: int) -> None: | ||
"""Initialize the KV cache with the given size in blocks. | ||
""" | ||
raise NotImplementedError | ||
|
||
def get_model(self) -> nn.Module: | ||
raise NotImplementedError | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since we have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, doing so prevents subclasses from being instantiated if they haven't implemented the abstract method, instead of only erroring out when the method is called. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think that's too restrictive, and hurts the development. erroring out when the method is called looks enough. we can add the implementation step by step. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree to remove unnecessary There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To avoid forgetting to implement them at the end, can we add a test that prints out which abstract methods remain to be implemented for each worker subclass? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This confuses me. What's the difference between throwing error when the method is called ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can make the test print out warnings instead of failing outright. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The difference is that users won't get unnecessary warnings while developers can remain aware of this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hi, @DarkLight1337 and @youkaichao . I propose a decorator to throw warnings for methods not implemented in sub class. it will show something like this. I still put one There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good to me. You can keep the |
||
|
||
def load_model(self) -> None: | ||
"""Load model onto target device.""" | ||
raise NotImplementedError | ||
|
||
def execute_model( | ||
self, | ||
execute_model_req: Optional[ExecuteModelRequest] = None | ||
) -> Optional[List[SamplerOutput]]: | ||
raise NotImplementedError | ||
|
||
@avoid_check | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel this is too complicated, we need to decorate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We probably need to go back to the original point of the discussion which is making development easier, right? Thanks @youkaichao for the reminder, I initially inherited v0 workerbase ( It's easy if you just remove
As a developer, you can choose to implement it or not, and you don't need to go through trial and error to find a way to achieve it. Right? So the final experience is that we only keep the key abstract methods, and tell the developers the rest through reminders BTW, @DarkLight1337 , we cannot use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about keeping abstractmethod but effectively disabling them by removing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@DarkLight1337 this makes life easier and won't throw errors if we don't implement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think currently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I see. This makes sense to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey guys, I really love the open source community's discussion atmosphere, but I also hope this PR doesn't drop halfway. In this case, I don't know if you members have any mechanisms? Like voting or something? No matter what kind of change it is, I might need a clear and uniform suggestion. Thank you! @youkaichao @DarkLight1337 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's gather some more thoughts before proceeding. V1 is high priority so it will get resolved soon! |
||
def start_worker_execution_loop(self) -> None: | ||
"""Execute model loop in parallel worker. | ||
|
||
|
@@ -94,40 +155,43 @@ def start_worker_execution_loop(self) -> None: | |
if output is None: | ||
return None | ||
|
||
@abstractmethod | ||
def get_model(self) -> nn.Module: | ||
raise NotImplementedError | ||
def determine_num_available_blocks(self) -> Tuple[int, int]: | ||
"""Determine the number of available blocks for the GPU KV cache and | ||
swappable CPU KV cache. | ||
|
||
@abstractmethod | ||
def execute_model( | ||
self, | ||
execute_model_req: Optional[ExecuteModelRequest] = None | ||
) -> Optional[List[SamplerOutput]]: | ||
The implementation may run profiling or other heuristics to determine | ||
the size of caches. | ||
|
||
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks | ||
are blocks that are "active" on the device and can be appended to. | ||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be | ||
appended to. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get_cache_block_size_bytes(self) -> int: | ||
"""Return the size of a single cache block, in bytes. Used in | ||
speculative decoding. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def add_lora(self, lora_request: LoRARequest) -> bool: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def remove_lora(self, lora_id: int) -> bool: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def pin_lora(self, lora_id: int) -> bool: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def list_loras(self) -> Set[int]: | ||
raise NotImplementedError | ||
|
||
@property | ||
def vocab_size(self) -> int: | ||
"""Get vocabulary size from model configuration.""" | ||
return self.model_config.get_vocab_size() | ||
|
||
|
||
class DelegateWorkerBase(WorkerBase): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel this part is too complicated to understand. Can we just have `raise NotImplementedError`?
If we don't get these errors, it means the function is not being called right now, and that's fine.
If we do get these errors, the developers will definitely see them and implement the method if necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since both @youkaichao and @DarkLight1337 are members of this project, I hope you can agree on a certain level of direction and positioning so that we developers can continue to contribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still `raise NotImplementedError` when the method is used; this code just logs additional warnings if some methods aren't implemented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It becomes easier to keep track of the feature completeness as we know which features are still missing without waiting for someone to try using that feature and then open a "bug report".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this helps to remind developers that some base class methods aren't implemented in inherited sub classes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel it's overkill and complicates the codebase a lot.