[V1][Metrics] Add GPU prefix cache hit rate % gauge (vllm-project#12592)
Signed-off-by: saeediy <[email protected]>
comaniac authored and Said-Akbar committed Mar 7, 2025
1 parent 1cb3ce1 commit 83439a1
Showing 7 changed files with 174 additions and 5 deletions.
2 changes: 2 additions & 0 deletions tests/entrypoints/openai/test_metrics.py
@@ -203,6 +203,8 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
"vllm:num_requests_running",
"vllm:num_requests_waiting",
"vllm:gpu_cache_usage_perc",
"vllm:gpu_prefix_cache_queries",
"vllm:gpu_prefix_cache_hits",
"vllm:prompt_tokens_total",
"vllm:generation_tokens_total",
"vllm:request_success_total",
39 changes: 38 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -5,10 +5,11 @@
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
KVCacheBlock, PrefixCachingMetrics,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request


@@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs():
assert block_hashes[0].extra_keys is None
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys is None


def test_metrics():
"""
Test the prefix caching metrics.
"""

def stats(requests, queries, hits):
return PrefixCacheStats(requests=requests, queries=queries, hits=hits)

metrics = PrefixCachingMetrics(interval=5)
assert metrics.hit_rate == 0.0

metrics.observe(stats(1, 20, 9))
# 9 / 20 = 0.45
assert metrics.hit_rate == 0.45

metrics.observe(stats(4, 80, 16))

# 25 / 100 = 0.25
assert metrics.hit_rate == 0.25

metrics.observe(stats(1, 10, 2))

# Remove (20, 9) and add (10, 2): 18 / 90 = 0.2
assert metrics.aggregated_requests == 5
assert metrics.aggregated_query_total == 90
assert metrics.aggregated_query_hit == 18
assert metrics.hit_rate == 0.2

metrics.reset()
assert metrics.hit_rate == 0.0
assert metrics.aggregated_requests == 0
assert metrics.aggregated_query_total == 0
assert metrics.aggregated_query_hit == 0
assert not metrics.query_queue
24 changes: 24 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
@@ -10,6 +10,7 @@
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)
@@ -78,11 +79,28 @@ def __init__(
self.req_to_block_hashes: DefaultDict[
str, List[BlockHashType]] = defaultdict(list)

self.prefix_cache_stats = PrefixCacheStats()

@property
def usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.free_block_queue.num_free_blocks /
self.num_gpu_blocks)

def make_prefix_cache_stats(self) -> PrefixCacheStats:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats.
"""
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats

def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
"""Get the computed (cached) blocks for the request.
@@ -118,6 +136,10 @@ def get_computed_blocks(
else:
break

self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)

# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
@@ -280,6 +302,8 @@ def reset_prefix_cache(self) -> bool:
for block in self.block_pool:
block.reset_hash()

self.prefix_cache_stats.reset = True

logger.info("Successfully reset prefix cache")
return True

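As a rough illustration of the accounting added to get_computed_blocks() above (a hedged sketch with assumed values, not part of this commit): `queries` counts the full-block hashes looked up for a request, and `hits` counts how many of those blocks were already in the cache.

# Hypothetical numbers for a single request; block_size=16 is an assumption.
block_size = 16
prompt_tokens = 35

queried_blocks = prompt_tokens // block_size  # 2 full blocks are hashed and looked up
cached_blocks = 1                             # suppose one of them is already cached

# Corresponds to:
#   prefix_cache_stats.requests += 1
#   prefix_cache_stats.queries  += queried_blocks
#   prefix_cache_stats.hits     += cached_blocks
print(f"hit rate for this request: {cached_blocks / queried_blocks:.0%}")  # 50%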
64 changes: 64 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities."""
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple
@@ -8,6 +9,7 @@
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -28,6 +30,68 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None


class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the most recent N requests.
Args:
interval: The number of the most recent requests to aggregate.
Defaults to 1000.
"""

def __init__(self, interval: int = 1000):
self.interval = interval
# The current aggregated values.
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
# A deque of (requests, queries, hits) for the most recent requests.
self.query_queue: deque[Tuple[int, int, int]] = deque()

def observe(self, stats: PrefixCacheStats):
"""Observe the prefix caching for a set of requests.
This function is called with information gathered when new requests
are being scheduled and are looking for computed blocks.
When there are more than `interval` requests, the oldest set of
requests are removed from the metrics.
Args:
stats: The prefix cache stats.
"""
# reset_prefix_cache was invoked before the current update.
# Reset the metrics before aggregating the current stats.
if stats.reset:
self.reset()

# Update the metrics.
self.query_queue.append((stats.requests, stats.queries, stats.hits))
self.aggregated_requests += stats.requests
self.aggregated_query_total += stats.queries
self.aggregated_query_hit += stats.hits

# Remove the oldest stats if the number of requests exceeds the interval.
if self.aggregated_requests > self.interval:
old_requests, old_queries, old_hits = self.query_queue.popleft()
self.aggregated_requests -= old_requests
self.aggregated_query_total -= old_queries
self.aggregated_query_hit -= old_hits

def reset(self):
"""Reset the metrics."""
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
self.query_queue.clear()

@property
def hit_rate(self) -> float:
"""Calculate the hit rate for the past N requests."""
if self.aggregated_query_total == 0:
return 0.0
return self.aggregated_query_hit / self.aggregated_query_total


@dataclass
class KVCacheBlock:
"""KV-cache block metadata."""
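A minimal usage sketch of the new PrefixCachingMetrics class (not part of the diff; it assumes a vLLM build containing this commit, and the small interval is only for illustration). In particular, it shows that a stats update carrying reset=True drops the accumulated window before being counted:

from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.metrics.stats import PrefixCacheStats

metrics = PrefixCachingMetrics(interval=5)
metrics.observe(PrefixCacheStats(requests=2, queries=40, hits=30))
assert metrics.hit_rate == 0.75  # 30 / 40

# After KVCacheManager.reset_prefix_cache(), the next stats carry reset=True,
# so observe() clears the sliding window before aggregating them.
metrics.observe(PrefixCacheStats(reset=True, requests=1, queries=10, hits=1))
assert metrics.hit_rate == 0.1  # only the post-reset request counts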
1 change: 1 addition & 0 deletions vllm/v1/core/scheduler.py
@@ -593,4 +593,5 @@ def make_stats(self) -> SchedulerStats:
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
)
29 changes: 27 additions & 2 deletions vllm/v1/metrics/loggers.py
@@ -9,6 +9,7 @@

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

@@ -37,6 +38,9 @@ def _reset(self, now):
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []

# Prefix cache metrics. TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()

def _local_interval_elapsed(self, now: float) -> bool:
# Log every _LOCAL_LOGGING_INTERVAL_SEC.
elapsed_time = now - self.last_log_time
@@ -58,6 +62,8 @@ def log(self, scheduler_stats: SchedulerStats,

self._track_iteration_stats(iteration_stats)

self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

now = time.monotonic()
if not self._local_interval_elapsed(now):
return
@@ -72,13 +78,15 @@
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs "
"GPU KV cache usage: %.1f%%.",
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.gpu_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)


@@ -107,6 +115,18 @@ def __init__(self, model_config: ModelConfig):
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames).labels(*labelvalues)

self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_queries",
documentation=
"GPU prefix cache queries, in terms of number of queried blocks.",
labelnames=labelnames).labels(*labelvalues)

self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_hits",
documentation=
"GPU prefix cache hits, in terms of number of cached blocks.",
labelnames=labelnames).labels(*labelvalues)

self.counter_prompt_tokens = prometheus_client.Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
@@ -170,6 +190,11 @@ def log(self, scheduler_stats: SchedulerStats,

self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)

self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)

self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc(
iteration_stats.num_generation_tokens)
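Note that the Prometheus logger exports the raw query and hit counters rather than a precomputed percentage, so the hit rate is derived at query time (for example, as the ratio of the two counters' increases over the same window). Below is a hedged sketch that computes the same ratio by scraping the plain-text metrics endpoint; the server URL and the lack of label filtering are assumptions, not part of this commit.

import urllib.request

def counter_total(metrics_text: str, name: str) -> float:
    # Sum the counter's samples, accepting either "name" or "name_total" as the
    # exposed sample name and ignoring label sets.
    total = 0.0
    for line in metrics_text.splitlines():
        if line.startswith("#"):
            continue
        sample = line.split("{")[0].split(" ")[0]
        if sample in (name, name + "_total"):
            total += float(line.rsplit(" ", 1)[-1])
    return total

text = urllib.request.urlopen("http://localhost:8000/metrics").read().decode()
queries = counter_total(text, "vllm:gpu_prefix_cache_queries")
hits = counter_total(text, "vllm:gpu_prefix_cache_hits")
if queries:
    print(f"GPU prefix cache hit rate: {hits / queries:.1%}")
else:
    print("No prefix cache queries recorded yet.")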
20 changes: 18 additions & 2 deletions vllm/v1/metrics/stats.py
@@ -1,14 +1,28 @@
# SPDX-License-Identifier: Apache-2.0

import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
from vllm.outputs import RequestOutput
from vllm.v1.engine import EngineCoreOutput, FinishReason


@dataclass
class PrefixCacheStats:
"""Stores prefix cache hit statistics."""
# Whether reset_prefix_cache was invoked.
reset: bool = False
# The number of requests in this update.
requests: int = 0
# The number of queries in these requests. Note that "queries" here
# means the number of blocks that were queried from the cache.
queries: int = 0
# The number of hits in these requests.
hits: int = 0


@dataclass
class SchedulerStats:
"""Stats associated with the scheduler."""
@@ -17,7 +31,9 @@ class SchedulerStats:
num_waiting_reqs: int = 0

gpu_cache_usage: float = 0.0
# gpu_prefix_cache_hit_rate: float = 0.0

prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)


@dataclass
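To tie the pieces together, a hedged end-to-end sketch (not in the diff, values assumed): each scheduler step, the KVCacheManager's accumulated stats are handed off via make_prefix_cache_stats(), carried on SchedulerStats, and folded into the logger-side sliding window.

from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats

window = PrefixCachingMetrics()  # created in the logger's _reset() (see loggers.py diff above)
step_stats = SchedulerStats(
    prefix_cache_stats=PrefixCacheStats(requests=3, queries=24, hits=18))

window.observe(step_stats.prefix_cache_stats)
print(f"{window.hit_rate:.0%}")  # 75%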
