From cf26f2cac473cd52def0068b431e186a0840a941 Mon Sep 17 00:00:00 2001 From: Aoyu Date: Tue, 11 Feb 2025 08:16:19 +0000 Subject: [PATCH 01/13] [V1][Core] Add worker_base for v1 worker 1. reuse WorkerBase in vllm.worker.worker_base 2. remove unnecessary abstract methods and only give warnings for unimplemented methods Signed-off-by: Aoyu --- vllm/v1/worker/gpu_worker.py | 28 +++----- vllm/v1/worker/worker_base.py | 66 +++++++++++++++++++ vllm/worker/worker_base.py | 121 +++++++++++++++++++++++++--------- 3 files changed, 166 insertions(+), 49 deletions(-) create mode 100644 vllm/v1/worker/worker_base.py diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0adb69073397c..fd832dd64e96c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -22,6 +22,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -29,7 +30,7 @@ from vllm.v1.core.scheduler import SchedulerOutput -class Worker: +class Worker(WorkerBase): def __init__( self, @@ -40,23 +41,11 @@ def __init__( is_driver_worker: bool = False, ): - # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method + super().__init__(vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -127,7 +116,8 @@ def init_device(self): set_random_seed(self.model_config.seed) # Construct the model runner - self.model_runner = GPUModelRunner(self.vllm_config, self.device) + self.model_runner: GPUModelRunner = GPUModelRunner( + self.vllm_config, self.device) def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py new file mode 100644 index 0000000000000..9d11da93b5e93 --- /dev/null +++ b/vllm/v1/worker/worker_base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 + +logger = init_logger(__name__) + + +class WorkerBase(WorkerBaseV0): + """ + Abstract class for v1 worker, mainly define some methods for v1. + For methods shared by v0 and v1, define them in v0 WorkerBase + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + """ + Initialize common worker components. + + Args: + vllm_config: Complete vLLM configuration + local_rank: Local device index + rank: Global rank in distributed setup + distributed_init_method: Distributed initialization method + is_driver_worker: Whether this worker handles driver + responsibilities + """ + # Configuration storage + super().__init__(vllm_config=vllm_config) + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + + # Device and model state + self.device: Optional[torch.device] = None + self.model_runner: Optional[nn.Module] = None + + @abstractmethod + def get_kv_cache_spec(self) -> KVCacheSpec: + """Get specifications for KV cache implementation.""" + raise NotImplementedError + + @abstractmethod + def compile_or_warm_up_model(self) -> None: + """Prepare model for execution through compilation/warmup.""" + raise NotImplementedError + + def check_health(self) -> None: + """Basic health check (override for device-specific checks).""" + return \ No newline at end of file diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 819b81fbfdbb2..ae7763452bde8 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -4,7 +4,9 @@ import os import time from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from functools import wraps +from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import cloudpickle import torch @@ -27,6 +29,64 @@ logger = init_logger(__name__) +def check_implementation(): + """ + A decorator that checks if all abstract methods from the base class + are implemented in the subclass and gives warnings for unimplemented + methods. + """ + + def decorator(cls: Type): + original_init = cls.__init__ + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + unimplemented_methods = [] + for attr_name in dir(self): + # bypass inner method + if attr_name.startswith('_'): + continue + base_method = getattr(self, attr_name) + # bypass method already defined + if getattr(base_method, '_avoid_check', False): + continue + # get the func of callable method + if callable(base_method): + base_method_name = base_method.__func__ + else: + continue + class_method = getattr(cls, attr_name, False) + # bypass method defined in sub class + if not class_method: + continue + if class_method == base_method_name: + unimplemented_methods.append(attr_name) + if unimplemented_methods: + method_names = ','.join(unimplemented_methods) + msg = (f"Methods {method_names} not implemented in {self}") + logger.warning(msg) + + cls.__init__ = wrapped_init + return cls + + return decorator + + +T = TypeVar('T') + + +def avoid_check(func: Callable[..., T]) -> Callable[..., T]: + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + return func(*args, **kwargs) + + wrapper._avoid_check = True # type: ignore + return wrapper + + +@check_implementation() class WorkerBase(ABC): """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to @@ -60,28 +120,26 @@ def init_device(self) -> None: """ raise NotImplementedError - @abstractmethod - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - The implementation may run profiling or other heuristics to determine - the size of caches. - - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - raise NotImplementedError - - @abstractmethod def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache with the given size in blocks. """ raise NotImplementedError + def get_model(self) -> nn.Module: + raise NotImplementedError + + def load_model(self) -> None: + """Load model onto target device.""" + raise NotImplementedError + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + raise NotImplementedError + + @avoid_check def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker. @@ -94,40 +152,43 @@ def start_worker_execution_loop(self) -> None: if output is None: return None - @abstractmethod - def get_model(self) -> nn.Module: - raise NotImplementedError + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. - @abstractmethod - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ raise NotImplementedError - @abstractmethod def get_cache_block_size_bytes(self) -> int: """Return the size of a single cache block, in bytes. Used in speculative decoding. """ raise NotImplementedError - @abstractmethod def add_lora(self, lora_request: LoRARequest) -> bool: raise NotImplementedError - @abstractmethod def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError - @abstractmethod def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError - @abstractmethod def list_loras(self) -> Set[int]: raise NotImplementedError + @property + def vocab_size(self) -> int: + """Get vocabulary size from model configuration.""" + return self.model_config.get_vocab_size() + class DelegateWorkerBase(WorkerBase): """ From b400e14cc666996f6af3c82b163563af8dba186f Mon Sep 17 00:00:00 2001 From: Aoyu Date: Tue, 11 Feb 2025 08:54:38 +0000 Subject: [PATCH 02/13] improve readability by adding warn_unimplemented_methods() Signed-off-by: Aoyu --- vllm/worker/worker_base.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ae7763452bde8..c0c64826cc6fb 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -28,6 +28,8 @@ logger = init_logger(__name__) +T = TypeVar('T') + def check_implementation(): """ @@ -36,12 +38,11 @@ def check_implementation(): methods. """ - def decorator(cls: Type): + def decorator(cls: Type[T]) -> Type[T]: + original_init = cls.__init__ - @wraps(original_init) - def wrapped_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) + def warn_unimplemented_methods(self: object): unimplemented_methods = [] for attr_name in dir(self): # bypass inner method @@ -56,26 +57,28 @@ def wrapped_init(self, *args, **kwargs): base_method_name = base_method.__func__ else: continue - class_method = getattr(cls, attr_name, False) + class_method_name = getattr(cls, attr_name, False) # bypass method defined in sub class - if not class_method: + if not class_method_name: continue - if class_method == base_method_name: + if class_method_name == base_method_name: unimplemented_methods.append(attr_name) if unimplemented_methods: method_names = ','.join(unimplemented_methods) msg = (f"Methods {method_names} not implemented in {self}") logger.warning(msg) + @wraps(original_init) + def wrapped_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + warn_unimplemented_methods(self) + cls.__init__ = wrapped_init return cls return decorator -T = TypeVar('T') - - def avoid_check(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) From c43e6bc9bba741d8904cc130fa513388fcf57dba Mon Sep 17 00:00:00 2001 From: Aoyu Date: Tue, 11 Feb 2025 09:29:26 +0000 Subject: [PATCH 03/13] fix for mypy [method-assign] error Signed-off-by: Aoyu --- vllm/worker/worker_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c0c64826cc6fb..6fc4ea453f8d9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -69,11 +69,11 @@ def warn_unimplemented_methods(self: object): logger.warning(msg) @wraps(original_init) - def wrapped_init(self, *args, **kwargs): + def wrapped_init(self, *args, **kwargs) -> None: original_init(self, *args, **kwargs) warn_unimplemented_methods(self) - cls.__init__ = wrapped_init + type.__setattr__(cls, '__init__', wrapped_init) return cls return decorator From 8e7906db5fd90335133de289cd90bac33484c15e Mon Sep 17 00:00:00 2001 From: Aoyu Date: Wed, 12 Feb 2025 02:44:40 +0000 Subject: [PATCH 04/13] refactor for easy maintenance Signed-off-by: Aoyu --- vllm/utils.py | 59 +++++++++++++++++++++++++++ vllm/v1/worker/worker_base.py | 7 ++-- vllm/worker/worker_base.py | 75 +++-------------------------------- 3 files changed, 68 insertions(+), 73 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 8b92695987573..669600e6bc7b8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2270,3 +2270,62 @@ def import_pynvml(): sys.modules["pynvml"] = pynvml spec.loader.exec_module(pynvml) return pynvml + + +def warn_for_unimplemented_methods(): + """ + A decorator that checks if all abstract methods from the base class + are implemented in the subclass and gives warnings for unimplemented + methods. + """ + + def decorator(cls: Type[T]) -> Type[T]: + + original_init = cls.__init__ + + def warn_unimplemented_methods(self: object): + unimplemented_methods = [] + for attr_name in dir(self): + # bypass inner method + if attr_name.startswith('_'): + continue + base_method = getattr(self, attr_name) + # bypass method already defined + if getattr(base_method, '_avoid_check', False): + continue + # get the func of callable method + if callable(base_method): + base_method_name = base_method.__func__ + else: + continue + class_method_name = getattr(cls, attr_name, False) + # bypass method defined in sub class + if not class_method_name: + continue + if class_method_name == base_method_name: + unimplemented_methods.append(attr_name) + if unimplemented_methods: + method_names = ','.join(unimplemented_methods) + msg = (f"Methods {method_names} not implemented in {self}") + logger.warning(msg) + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs) -> None: + original_init(self, *args, **kwargs) + warn_unimplemented_methods(self) + + type.__setattr__(cls, '__init__', wrapped_init) + return cls + + return decorator + + +def avoid_warn_for_unimplementation( + func: Callable[..., T]) -> Callable[..., T]: + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + return func(*args, **kwargs) + + wrapper._avoid_check = True # type: ignore + return wrapper diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 9d11da93b5e93..9816afb70a14a 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from abc import abstractmethod from typing import Optional import torch @@ -8,12 +7,14 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.utils import warn_for_unimplemented_methods from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 logger = init_logger(__name__) +@warn_for_unimplemented_methods() class WorkerBase(WorkerBaseV0): """ Abstract class for v1 worker, mainly define some methods for v1. @@ -51,16 +52,14 @@ def __init__( self.device: Optional[torch.device] = None self.model_runner: Optional[nn.Module] = None - @abstractmethod def get_kv_cache_spec(self) -> KVCacheSpec: """Get specifications for KV cache implementation.""" raise NotImplementedError - @abstractmethod def compile_or_warm_up_model(self) -> None: """Prepare model for execution through compilation/warmup.""" raise NotImplementedError def check_health(self) -> None: """Basic health check (override for device-specific checks).""" - return \ No newline at end of file + return diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6fc4ea453f8d9..a5066ab78e656 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -3,10 +3,8 @@ import dataclasses import os import time -from abc import ABC, abstractmethod -from functools import wraps -from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import cloudpickle import torch @@ -19,7 +17,8 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import (enable_trace_function_call_for_thread, +from vllm.utils import (avoid_warn_for_unimplementation, + enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, @@ -28,69 +27,8 @@ logger = init_logger(__name__) -T = TypeVar('T') - -def check_implementation(): - """ - A decorator that checks if all abstract methods from the base class - are implemented in the subclass and gives warnings for unimplemented - methods. - """ - - def decorator(cls: Type[T]) -> Type[T]: - - original_init = cls.__init__ - - def warn_unimplemented_methods(self: object): - unimplemented_methods = [] - for attr_name in dir(self): - # bypass inner method - if attr_name.startswith('_'): - continue - base_method = getattr(self, attr_name) - # bypass method already defined - if getattr(base_method, '_avoid_check', False): - continue - # get the func of callable method - if callable(base_method): - base_method_name = base_method.__func__ - else: - continue - class_method_name = getattr(cls, attr_name, False) - # bypass method defined in sub class - if not class_method_name: - continue - if class_method_name == base_method_name: - unimplemented_methods.append(attr_name) - if unimplemented_methods: - method_names = ','.join(unimplemented_methods) - msg = (f"Methods {method_names} not implemented in {self}") - logger.warning(msg) - - @wraps(original_init) - def wrapped_init(self, *args, **kwargs) -> None: - original_init(self, *args, **kwargs) - warn_unimplemented_methods(self) - - type.__setattr__(cls, '__init__', wrapped_init) - return cls - - return decorator - - -def avoid_check(func: Callable[..., T]) -> Callable[..., T]: - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> T: - return func(*args, **kwargs) - - wrapper._avoid_check = True # type: ignore - return wrapper - - -@check_implementation() -class WorkerBase(ABC): +class WorkerBase: """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to communicate request metadata to other workers. @@ -116,7 +54,6 @@ def __init__( from vllm.platforms import current_platform self.current_platform = current_platform - @abstractmethod def init_device(self) -> None: """Initialize device state, such as loading the model or other on-device memory allocations. @@ -142,7 +79,7 @@ def execute_model( ) -> Optional[List[SamplerOutput]]: raise NotImplementedError - @avoid_check + @avoid_warn_for_unimplementation def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker. From 4f86570426bc8b23e06131133ec68fa0b52e2ead Mon Sep 17 00:00:00 2001 From: Aoyu Date: Wed, 12 Feb 2025 09:08:16 +0000 Subject: [PATCH 05/13] remove _avoid_check Signed-off-by: Aoyu --- vllm/utils.py | 11 ----------- vllm/v1/worker/worker_base.py | 2 -- vllm/worker/worker_base.py | 8 ++++---- 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index ac419cdfcb748..dce660a0d01ee 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2301,14 +2301,3 @@ def wrapped_init(self, *args, **kwargs) -> None: return cls return decorator - - -def avoid_warn_for_unimplementation( - func: Callable[..., T]) -> Callable[..., T]: - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> T: - return func(*args, **kwargs) - - wrapper._avoid_check = True # type: ignore - return wrapper diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 9816afb70a14a..bc7e76c38aed3 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -7,14 +7,12 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import warn_for_unimplemented_methods from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 logger = init_logger(__name__) -@warn_for_unimplemented_methods() class WorkerBase(WorkerBaseV0): """ Abstract class for v1 worker, mainly define some methods for v1. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a5066ab78e656..6de3f563414b9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -17,10 +17,10 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import (avoid_warn_for_unimplementation, - enable_trace_function_call_for_thread, +from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, - update_environment_variables) + update_environment_variables, + warn_for_unimplemented_methods) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -28,6 +28,7 @@ logger = init_logger(__name__) +@warn_for_unimplemented_methods() class WorkerBase: """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to @@ -79,7 +80,6 @@ def execute_model( ) -> Optional[List[SamplerOutput]]: raise NotImplementedError - @avoid_warn_for_unimplementation def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker. From 86d0705c0ee29d1d6d350c33e7cdd94f9df37f3d Mon Sep 17 00:00:00 2001 From: Aoyu Date: Wed, 12 Feb 2025 11:45:16 +0000 Subject: [PATCH 06/13] fix per-commit Signed-off-by: Aoyu --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index dce660a0d01ee..9c0fcb5aae1d8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2224,7 +2224,6 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any], return func(*args, **kwargs) - def import_pynvml(): """ Historical comments: @@ -2255,6 +2254,7 @@ def import_pynvml(): import vllm.third_party.pynvml as pynvml return pynvml + def warn_for_unimplemented_methods(): """ A decorator that checks if all abstract methods from the base class From cb8a099aea091f10e2f92fa1b1415c99c8e5eb2d Mon Sep 17 00:00:00 2001 From: Aoyu Date: Thu, 13 Feb 2025 03:20:34 +0000 Subject: [PATCH 07/13] remove avoid_check Signed-off-by: Aoyu --- vllm/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 9c0fcb5aae1d8..73b76dc949162 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2273,9 +2273,6 @@ def warn_unimplemented_methods(self: object): if attr_name.startswith('_'): continue base_method = getattr(self, attr_name) - # bypass method already defined - if getattr(base_method, '_avoid_check', False): - continue # get the func of callable method if callable(base_method): base_method_name = base_method.__func__ From 889c72a5f594c2e909c043fd96370be0bf4973e7 Mon Sep 17 00:00:00 2001 From: Aoyu Date: Thu, 13 Feb 2025 03:28:29 +0000 Subject: [PATCH 08/13] check NotImplementedError directly Signed-off-by: Aoyu --- vllm/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 73b76dc949162..a141cbbbfb52c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2278,11 +2278,8 @@ def warn_unimplemented_methods(self: object): base_method_name = base_method.__func__ else: continue - class_method_name = getattr(cls, attr_name, False) - # bypass method defined in sub class - if not class_method_name: - continue - if class_method_name == base_method_name: + src = inspect.getsource(base_method_name) + if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: method_names = ','.join(unimplemented_methods) From a68c0fe381b638083f0de2d6391dacf74e2d796d Mon Sep 17 00:00:00 2001 From: Aoyu Date: Thu, 13 Feb 2025 04:10:30 +0000 Subject: [PATCH 09/13] remove decorator inside warn_for func Signed-off-by: Aoyu --- vllm/utils.py | 74 ++++++++++++++++++-------------------- vllm/worker/worker_base.py | 2 +- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index a141cbbbfb52c..517e4167659e4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2255,43 +2255,37 @@ def import_pynvml(): return pynvml -def warn_for_unimplemented_methods(): - """ - A decorator that checks if all abstract methods from the base class - are implemented in the subclass and gives warnings for unimplemented - methods. - """ - - def decorator(cls: Type[T]) -> Type[T]: - - original_init = cls.__init__ - - def warn_unimplemented_methods(self: object): - unimplemented_methods = [] - for attr_name in dir(self): - # bypass inner method - if attr_name.startswith('_'): - continue - base_method = getattr(self, attr_name) - # get the func of callable method - if callable(base_method): - base_method_name = base_method.__func__ - else: - continue - src = inspect.getsource(base_method_name) - if "NotImplementedError" in src: - unimplemented_methods.append(attr_name) - if unimplemented_methods: - method_names = ','.join(unimplemented_methods) - msg = (f"Methods {method_names} not implemented in {self}") - logger.warning(msg) - - @wraps(original_init) - def wrapped_init(self, *args, **kwargs) -> None: - original_init(self, *args, **kwargs) - warn_unimplemented_methods(self) - - type.__setattr__(cls, '__init__', wrapped_init) - return cls - - return decorator +def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]: + + original_init = cls.__init__ + + def find_unimplemented_methods(self: object): + unimplemented_methods = [] + for attr_name in dir(self): + # bypass inner method + if attr_name.startswith('_'): + continue + base_method = getattr(self, attr_name) + # get the func of callable method + if callable(base_method): + base_method_name = base_method.__func__ + else: + continue + class_method_name = getattr(cls, attr_name, False) + # bypass method defined in sub class + if not class_method_name: + continue + if class_method_name == base_method_name: + unimplemented_methods.append(attr_name) + if unimplemented_methods: + method_names = ','.join(unimplemented_methods) + msg = (f"Methods {method_names} not implemented in {self}") + logger.warning(msg) + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs) -> None: + original_init(self, *args, **kwargs) + find_unimplemented_methods(self) + + type.__setattr__(cls, '__init__', wrapped_init) + return cls diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6de3f563414b9..aca522b9f66ad 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -28,7 +28,7 @@ logger = init_logger(__name__) -@warn_for_unimplemented_methods() +@warn_for_unimplemented_methods class WorkerBase: """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to From 47ceccf81081a2a78bea4a5747c43f86d9226ad3 Mon Sep 17 00:00:00 2001 From: Aoyu Date: Thu, 13 Feb 2025 05:31:00 +0000 Subject: [PATCH 10/13] add NotImplementedError check Signed-off-by: Aoyu --- vllm/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 517e4167659e4..218c94d68b888 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2271,11 +2271,8 @@ def find_unimplemented_methods(self: object): base_method_name = base_method.__func__ else: continue - class_method_name = getattr(cls, attr_name, False) - # bypass method defined in sub class - if not class_method_name: - continue - if class_method_name == base_method_name: + src = inspect.getsource(base_method_name) + if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: method_names = ','.join(unimplemented_methods) From c5b222c06136f39e03f1df7415536ae2e68f7824 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 13 Feb 2025 13:31:32 +0800 Subject: [PATCH 11/13] use getsource Signed-off-by: youkaichao --- vllm/utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 517e4167659e4..9f4b70c128a1d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2256,6 +2256,14 @@ def import_pynvml(): def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]: + """ + A replacement for `abc.ABC`. + When we use `abc.ABC`, subclasses will fail to instantiate + if they do not implement all abstract methods. + Here, we only require `raise NotImplementedError` in the + base class, and log a warning if the method is not implemented + in the subclass. + """ original_init = cls.__init__ @@ -2271,11 +2279,8 @@ def find_unimplemented_methods(self: object): base_method_name = base_method.__func__ else: continue - class_method_name = getattr(cls, attr_name, False) - # bypass method defined in sub class - if not class_method_name: - continue - if class_method_name == base_method_name: + src = inspect.getsource(base_method_name) + if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: method_names = ','.join(unimplemented_methods) From 84b2dabfa1da731131828cf9091536c2c4580fcc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 13 Feb 2025 13:40:37 +0800 Subject: [PATCH 12/13] catch error Signed-off-by: youkaichao --- vllm/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 9f4b70c128a1d..cf960140ecabb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2273,13 +2273,15 @@ def find_unimplemented_methods(self: object): # bypass inner method if attr_name.startswith('_'): continue - base_method = getattr(self, attr_name) - # get the func of callable method - if callable(base_method): - base_method_name = base_method.__func__ - else: + + try: + attr = getattr(self, attr_name) + # get the func of callable method + if callable(attr): + attr_func = attr.__func__ + except AttributeError: continue - src = inspect.getsource(base_method_name) + src = inspect.getsource(attr_func) if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: From 793fe42fe0d41a17f02f5a73c846875732e1a963 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 13 Feb 2025 16:36:25 +0800 Subject: [PATCH 13/13] fix spec decode Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index aca522b9f66ad..83fcf0865ae1c 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -157,6 +157,10 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def load_model(self) -> None: + """Load model onto target device.""" + self.worker.load_model() + def get_model(self) -> nn.Module: return self.worker.get_model()