diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 5a26ccafb46d7..660f26c71c381 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -621,11 +621,25 @@ def forward(
         if prefill_meta := attn_metadata.prefill_metadata:
             output = torch.empty_like(query)
-            (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
-             key_max_seq_len, seq_lens,
-             causal_mask) = _get_seq_len_block_table_args(
-                 prefill_meta, attn_type)
+            # Prompt run.
+            # normal attention and DECODER
+            if attn_type == AttentionType.DECODER and (
+                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
+                    or prefill_meta.block_tables.numel() == 0):
+                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
+                 key_max_seq_len, seq_lens,
+                 causal_mask) = (prefill_meta.seq_start_loc,
+                                 prefill_meta.max_prefill_seq_len,
+                                 prefill_meta.seq_start_loc,
+                                 prefill_meta.max_prefill_seq_len,
+                                 attn_metadata.seq_lens, True)
+            # prefix-enabled attention and ENCODER/ENCODER_DECODER
+            else:
+                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
+                 key_max_seq_len, seq_lens,
+                 causal_mask) = _get_seq_len_block_table_args(
+                     prefill_meta, attn_type)
             # Prompt run.
             if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                 # triton attention
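
For reference, a minimal standalone sketch of the dispatch this hunk introduces, using simplified stand-in classes rather than the actual vLLM types: a plain DECODER prefill with an empty KV cache or no block tables reuses the prefill metadata's `seq_start_loc` / `max_prefill_seq_len` for both query and key and enables the causal mask, while every other case (prefix-enabled attention, ENCODER, ENCODER_DECODER) still goes through `_get_seq_len_block_table_args`.

```python
# Sketch only: AttentionType, PrefillMeta, and the helper below are simplified
# stand-ins for the vLLM classes referenced in the diff, not the real ones.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch


class AttentionType(Enum):
    DECODER = auto()
    ENCODER = auto()
    ENCODER_DECODER = auto()


@dataclass
class PrefillMeta:
    seq_start_loc: torch.Tensor   # cumulative token offsets per sequence
    max_prefill_seq_len: int
    block_tables: Optional[torch.Tensor]


def _get_seq_len_block_table_args(prefill_meta, attn_type):
    # Placeholder for the helper used on the prefix-caching / encoder paths;
    # its return shape mirrors the tuple returned below.
    raise NotImplementedError


def select_prefill_args(attn_type, kv_cache, prefill_meta, seq_lens):
    # Mirrors the new branch: plain DECODER prefill with no usable block
    # tables shares one sequence layout for query and key and is causal.
    if attn_type == AttentionType.DECODER and (
            kv_cache.numel() == 0 or prefill_meta.block_tables is None
            or prefill_meta.block_tables.numel() == 0):
        return (prefill_meta.seq_start_loc, prefill_meta.max_prefill_seq_len,
                prefill_meta.seq_start_loc, prefill_meta.max_prefill_seq_len,
                seq_lens, True)
    # Prefix-enabled attention and ENCODER/ENCODER_DECODER fall back to the
    # shared helper.
    return _get_seq_len_block_table_args(prefill_meta, attn_type)


if __name__ == "__main__":
    meta = PrefillMeta(seq_start_loc=torch.tensor([0, 4, 9]),
                       max_prefill_seq_len=5,
                       block_tables=None)
    print(select_prefill_args(AttentionType.DECODER, torch.empty(0), meta,
                              seq_lens=[4, 5]))
```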