forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[V1][Core] Add worker_base for v1 worker (vllm-project#12816)
Signed-off-by: Aoyu <[email protected]> Signed-off-by: youkaichao <[email protected]> Co-authored-by: Aoyu <[email protected]> Co-authored-by: youkaichao <[email protected]> Signed-off-by: Linkun Chen <[email protected]>
- Loading branch information
Showing
4 changed files
with
153 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from vllm.config import VllmConfig | ||
from vllm.logger import init_logger | ||
from vllm.v1.kv_cache_interface import KVCacheSpec | ||
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class WorkerBase(WorkerBaseV0):
    """
    Abstract base for v1 workers.

    Only the methods specific to v1 are declared here. Anything that v0
    and v1 have in common should be defined on the v0 WorkerBase, which
    this class extends.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Set up the state common to every v1 worker.

        Args:
            vllm_config: Complete vLLM configuration; forwarded unchanged
                to the v0 base-class constructor.
            local_rank: Local device index.
            rank: Global rank in the distributed setup.
            distributed_init_method: Distributed initialization method.
            is_driver_worker: Whether this worker handles driver
                responsibilities. Defaults to False.
        """
        # The v0 base class takes care of storing the configuration.
        super().__init__(vllm_config=vllm_config)

        # Device and model runner are not known at construction time,
        # so both start out as None.
        self.device: Optional[torch.device] = None
        self.model_runner: Optional[nn.Module] = None

        # Identity of this worker within the distributed setup.
        self.is_driver_worker = is_driver_worker
        self.distributed_init_method = distributed_init_method
        self.rank = rank
        self.local_rank = local_rank

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the specification of the KV cache implementation.

        Raises:
            NotImplementedError: Always; this base class provides no
                implementation.
        """
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Get the model ready for execution via compilation/warmup.

        Raises:
            NotImplementedError: Always; this base class provides no
                implementation.
        """
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check; a no-op here.

        Override to add device-specific checks.
        """
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters