[bugfix] fix early import of flash attention (#12959)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Feb 8, 2025
1 parent 913df14 commit fe743b7
Showing 4 changed files with 20 additions and 19 deletions.
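The fix follows a single pattern across all four files: the flash-attention version is no longer resolved at module import time (the old module-level VLLM_FLASH_ATTN_VERSION constant forced vllm.vllm_flash_attn to load as soon as vllm.attention.backends.utils was imported), but lazily, via a new get_flash_attn_version() helper that each backend calls from its __init__ and stores as self.vllm_flash_attn_version. A minimal sketch of the pattern, using hypothetical names rather than the actual vLLM code:

```python
# Minimal sketch of the lazy-resolution pattern this commit applies.
# The package name `flash_attn_ext` and the class below are hypothetical.

# Before: evaluated at import time, so importing this module eagerly
# loaded the flash-attention extension (and failed where it is absent).
# from flash_attn_ext import is_version_supported
# FLASH_ATTN_VERSION = 3 if is_version_supported(3) else 2

# After: the import is deferred until a backend actually needs the value.
def get_version():
    try:
        from flash_attn_ext import is_version_supported  # hypothetical package
        return 3 if is_version_supported(3) else 2
    except ImportError:
        # No extension available on this platform: report "unavailable".
        return None


class ExampleAttentionImpl:
    def __init__(self):
        # Resolved per instance at construction time, not at module import.
        self.fa_version = get_version()
```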
13 changes: 7 additions & 6 deletions vllm/attention/backends/flash_attn.py
@@ -14,8 +14,8 @@
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (
- PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
- compute_slot_mapping, compute_slot_mapping_start_idx,
+ PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
+ compute_slot_mapping_start_idx, get_flash_attn_version,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
is_block_tables_empty)
@@ -640,6 +640,7 @@ def __init__(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
+ self.vllm_flash_attn_version = get_flash_attn_version()

def forward(
self,
@@ -759,7 +760,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)
else:
# prefix-enabled attention
@@ -782,7 +783,7 @@ def forward(
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)

if decode_meta := attn_metadata.decode_metadata:
@@ -811,7 +812,7 @@ def forward(
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
@@ -832,7 +833,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)
return output

5 changes: 3 additions & 2 deletions vllm/attention/backends/mla/utils.py
@@ -12,7 +12,7 @@
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
- from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+ from vllm.attention.backends.utils import get_flash_attn_version
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -181,6 +181,7 @@ def __init__(
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
+ self.vllm_flash_attn_version = get_flash_attn_version()

def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -515,7 +516,7 @@ def _forward_prefill_flash(
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)
attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
14 changes: 6 additions & 8 deletions vllm/attention/backends/utils.py
@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
num_decode_query_tokens)


- try:
-     from vllm.vllm_flash_attn.flash_attn_interface import (
-         fa_version_unsupported_reason, is_fa_version_supported)
+ def get_flash_attn_version():
+     try:
+         from vllm.vllm_flash_attn.flash_attn_interface import (
+             fa_version_unsupported_reason, is_fa_version_supported)

-     def flash_attn_version():
        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        # use FA3 as default for both
@@ -610,7 +610,5 @@ def flash_attn_version():

        assert is_fa_version_supported(fa_version)
        return fa_version
-
-     VLLM_FLASH_ATTN_VERSION = flash_attn_version()
- except (ImportError, AssertionError):
-     VLLM_FLASH_ATTN_VERSION = None
+     except (ImportError, AssertionError):
+         return None
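Because the try/except now wraps the body of get_flash_attn_version() instead of module-level code, importing vllm.attention.backends.utils no longer requires the flash-attention extension; callers simply receive None when it is unavailable. A hedged usage sketch (the printed messages are illustrative, not vLLM's actual logging):

```python
# Usage sketch for the new helper; assumes vLLM is installed.
# The printed messages are illustrative only.
from vllm.attention.backends.utils import get_flash_attn_version

fa_version = get_flash_attn_version()
if fa_version is None:
    # vllm_flash_attn is missing, or no supported FA version was found.
    print("vLLM flash-attn unavailable; a different attention backend is needed")
else:
    print(f"Using vLLM flash-attn version {fa_version}")
```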
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -10,7 +10,7 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
- from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+ from vllm.attention.backends.utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ def __init__(
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
+ self.vllm_flash_attn_version = get_flash_attn_version()

def forward(
self,
@@ -205,7 +206,7 @@ def forward(
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)
return output

@@ -227,7 +228,7 @@ def forward(
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
- fa_version=VLLM_FLASH_ATTN_VERSION,
+ fa_version=self.vllm_flash_attn_version,
)
return output
