Allow use of flex_attention instead of FSDPA in HPUAttentionImpl (#876)
Co-authored-by: Michał Kuligowski <[email protected]>
m-a-nowak and michalkuligowski authored Mar 7, 2025
1 parent 03df014 commit 34ba9ed
Showing 3 changed files with 30 additions and 17 deletions.
1 change: 1 addition & 0 deletions README_GAUDI.md
@@ -373,6 +373,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
- `PT_HPU_LAZY_MODE`: if `0`, the PyTorch Eager backend for Gaudi will be used; if `1`, the PyTorch Lazy backend for Gaudi will be used. `1` is the default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
- `VLLM_PROMPT_USE_FLEX_ATTENTION` takes effect only for the Llama model and allows `torch.nn.attention.flex_attention` to be used instead of FusedSDPA. Note that this requires `VLLM_PROMPT_USE_FUSEDSDPA=0` (see the example below).

# Quantization, FP8 Inference and Model Calibration Process

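For readers who want to try the new flag, below is a minimal sketch of how the environment variables documented above might be combined. It is illustrative only: the accepted values, the Llama checkpoint name, and setting the variables before importing vLLM are assumptions, not taken from this repository.

```python
# Hedged sketch: enabling the prompt flex_attention path on Gaudi via the
# environment variables documented above. The exact values and the model
# checkpoint are assumptions for illustration.
import os

os.environ["VLLM_PROMPT_USE_FLEX_ATTENTION"] = "1"  # opt into torch.nn.attention.flex_attention
os.environ["VLLM_PROMPT_USE_FUSEDSDPA"] = "0"       # FusedSDPA must be disabled per the note above
os.environ["PT_HPU_LAZY_MODE"] = "1"                # PyTorch Lazy backend for Gaudi (the default)

from vllm import LLM  # import after the environment is configured

# The flag is documented as Llama-only; the checkpoint below is a hypothetical example.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
outputs = llm.generate(["Hello, Gaudi!"])
print(outputs[0].outputs[0].text)
```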
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3fd0250
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@ecb60e4
44 changes: 28 additions & 16 deletions vllm/attention/backends/hpu_attn.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
# Copyright (C) 2024-2025 Habana Labs, Ltd. an Intel Company
###############################################################################

from dataclasses import dataclass
@@ -159,6 +159,8 @@ def __init__(
logger().warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

self.prefill_use_flex_attention = "flex_attention" in enabled_flags()

supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
@@ -237,7 +239,8 @@ def forward(
self.head_size)

if attn_metadata is None or attn_metadata.block_list is None:
if not self.prefill_use_fusedsdpa:
if (not self.prefill_use_fusedsdpa
and not self.prefill_use_flex_attention):
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
@@ -252,20 +255,29 @@ def forward(
else:
attn_bias = attn_metadata.attn_bias

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
fsdpa_op=self.fused_scaled_dot_product_attention
if self.prefill_use_fusedsdpa else None,
)
if not self.prefill_use_flex_attention:
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
fsdpa_op=self.fused_scaled_dot_product_attention
if self.prefill_use_fusedsdpa else None,
)
else:
out = ops.flex_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
scale=self.scale,
)

else:
# TODO: enable FusedSDPA
out = HPUPagedAttention.forward_prefix(
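To make the new branch in `forward()` concrete, here is a standalone sketch of `torch.nn.attention.flex_attention`, the PyTorch API that the README note says backs this path. It is not the `ops.flex_attention` wrapper from vllm-hpu-extension; the tensor shapes, the causal `score_mod`, and running it eagerly are assumptions for illustration.

```python
# Hedged sketch of torch.nn.attention.flex_attention (PyTorch >= 2.5), the API
# behind the new prefill branch per the README. Shapes and the causal score_mod
# are illustrative assumptions; vllm-hpu-extension's ops.flex_attention may
# differ in how it builds masks and handles sequence lengths.
import torch
from torch.nn.attention.flex_attention import flex_attention


def causal(score, b, h, q_idx, kv_idx):
    # Keep the score where the key position does not exceed the query position;
    # otherwise mask it out, giving standard causal attention.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))


batch, heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# `scale` mirrors the `scale=self.scale` argument passed in the diff above.
out = flex_attention(q, k, v, score_mod=causal, scale=head_dim**-0.5)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```

In the diff itself, the choice between this path and FusedSDPA is made once in `__init__` via the `flex_attention` entry in `enabled_flags()`, which the README ties to `VLLM_PROMPT_USE_FLEX_ATTENTION` together with `VLLM_PROMPT_USE_FUSEDSDPA=0`.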
