diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index bf3992281a735..bbbdf50ac0cc7 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -12,8 +12,8 @@
 from vllm.attention.backends.utils import CommonAttentionState
 
 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
-NUM_KV_PAGES_PER_BLOCK = 256
+NUM_QUERIES_PER_BLOCK = 32
+NUM_KV_PAGES_PER_BLOCK = 128
 
 
 class PallasAttentionBackend(AttentionBackend):
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index d4ebb3adcf8dc..d564a964a37fa 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -22,9 +22,7 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -77,10 +75,8 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = _get_padded_number(
-            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
-        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
-                                               NUM_QUERIES_PER_BLOCK)
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -141,16 +137,8 @@ def __init__(
             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        # self.input_batch.block_table has a shape of [max_num_reqs,
-        # max_num_blocks_per_req]. To reduce the number of recompilation,
-        # we want the block_table.shape[0] to be num_tokens.
-        # To make the block_table to be compatible with the paged attention
-        # kernel, we want the block_table[1] to be multiple of
-        # NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
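
For context on what the removed call sites did: the diff drops the padding of max_num_tokens, max_num_reqs, and the block table's second dimension to multiples of the Pallas kernel block sizes. The sketch below is illustrative only; it assumes _get_padded_number in tpu_model_runner.py is the usual round-up-to-multiple helper, which is implied by the removed comments but not shown in this diff.

# Sketch only: assumed behavior of _get_padded_number(x, multiple),
# i.e. round x up to the nearest multiple of `multiple`.
def _get_padded_number(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

# Before this change the runner padded its sizing knobs to kernel blocks, e.g.
#   max_num_tokens = _get_padded_number(max_num_batched_tokens, 16)  # 1000 -> 1008
#   padded_max_num_blocks_per_req = _get_padded_number(max_num_blocks_per_req, 256)
# After this change, max_num_tokens and max_num_reqs are taken directly from
# the scheduler config, and block_table_cpu keeps the unpadded
# max_num_blocks_per_req as its second dimension.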