I recently benchmarked FlashAttention against PyTorch’s scaled_dot_product_attention using a custom script and observed a significant discrepancy in the output. Below are the details of the issue.
Here is the benchmark code I used for testing:
import torch
import torch.nn.functional as F
import math


def PyTorchAttention(query_layer, key_layer, value_layer, attention_mask):
    """
    Attention implementation using PyTorch's scaled_dot_product_attention.
    """
    B, H, L, D = query_layer.shape
    # Compute attention using PyTorch's built-in function
    context_layer = F.scaled_dot_product_attention(
        query_layer, key_layer, value_layer,
        attn_mask=attention_mask,
        is_causal=False,
        scale=1,
    )
    # Rearrange and reshape the context layer
    context_layer = context_layer.reshape(B, L, H*D)
    return context_layer


def FlashAttention(query_layer, key_layer, value_layer, attention_mask):
    """
    Attention implementation using FlashAttention.
    """
    B, H, L, D = query_layer.shape
    qkv = torch.stack((query_layer, key_layer, value_layer), dim=1).reshape(B, L, 3, H, D)
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn import flash_attn_varlen_qkvpacked_func
    # Unpad the input sequences based on the attention mask
    qkv_unpadded, indices, cu_seqlens, max_seqlen, _ = unpad_input(
        hidden_states=qkv, attention_mask=attention_mask.reshape(B, L)
    )
    # Apply FlashAttention
    fa_out = flash_attn_varlen_qkvpacked_func(
        qkv_unpadded,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        softmax_scale=1.0,
        causal=False,
    )
    # Pad the output back to the original sequence length
    fa_out = pad_input(fa_out, indices, B, L).to(torch.float32)
    # Reshape and rearrange the output to match the expected shape
    fa_out = fa_out.reshape(B, L, H*D)
    return fa_out


def compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype=torch.float16):
    """
    Generates random inputs and compares the outputs of the two attention implementations.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Generate random inputs
    query_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    key_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    value_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    # Generate a random attention mask
    attention_mask = torch.randint(0, 2, (batch_size, seq_length), dtype=torch.bool, device=device)
    attn_mask = attention_mask[:, None, None, :]  # [batch_size, 1, 1, seq_length]
    # Compute outputs from both implementations
    output1 = PyTorchAttention(query_layer, key_layer, value_layer, attn_mask)
    output2 = FlashAttention(query_layer, key_layer, value_layer, attention_mask)
    # Compare outputs
    difference = torch.abs(output1 - output2)
    max_diff = difference.max().item()
    mean_diff = difference.mean().item()
    print(f'Data type: {dtype}')
    print(f'Max absolute difference: {max_diff}')
    print(f'Mean absolute difference: {mean_diff}')


def seed_all(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == '__main__':
    # Set parameters
    batch_size = 2
    num_heads = 4
    seq_length = 128
    head_dim = 64  # Dimension per head
    dtype = torch.float16  # Change to torch.float32 or torch.float64 as needed
    seed_all(42)
    # Compare implementations
    from torch.nn.attention import sdpa_kernel, SDPBackend
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype)
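One detail of the mask handling that the raw comparison above does not separate out: pad_input scatters the FlashAttention results back into a zero-initialized tensor, so the padded positions of output2 are exactly zero, while scaled_dot_product_attention still computes values for those query positions because attn_mask only masks keys. Below is a minimal sketch of restricting the comparison to valid tokens; it assumes it is placed inside compare_attention_implementations, where output1, output2, and attention_mask are in scope:

# Restrict the diff to the positions the (batch_size, seq_length) boolean mask marks as valid;
# padded rows differ by construction (zeros from pad_input vs. computed SDPA values).
valid = attention_mask.bool()                        # [batch_size, seq_length]
masked_diff = torch.abs(output1 - output2)[valid]    # [num_valid_tokens, H*D]
print(f'Max abs difference (valid tokens only): {masked_diff.max().item()}')
print(f'Mean abs difference (valid tokens only): {masked_diff.mean().item()}')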
The discrepancy is consistent across multiple runs, even with deterministic settings (e.g., fixed seeds).
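To check that consistency, I rerun the comparison under a few different fixed seeds, e.g. with a loop like the following in place of the single call in the __main__ block (the seed values are arbitrary):

# Repeat the comparison under several fixed seeds; the specific seed values are arbitrary.
for seed in (0, 42, 1234):
    seed_all(seed)
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype)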
Questions and Notes:
1. Is this discrepancy expected under certain conditions (e.g., the torch.float16 data type or the attention mask handling)? A sketch of the tolerance check I have in mind follows below the questions.
2. If not, could this indicate a potential bug or precision issue in my benchmark implementation?
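Regarding question 1, this is the kind of tolerance check I would consider reasonable for float16 outputs. The atol/rtol values below are my own assumptions rather than documented tolerances from either library, and output1/output2 refer to the tensors computed inside compare_attention_implementations:

# Hedged tolerance check for half-precision results; atol/rtol are assumed values,
# not numbers taken from PyTorch or FlashAttention documentation.
try:
    torch.testing.assert_close(output2.to(output1.dtype), output1, atol=1e-2, rtol=1e-3)
    print('Outputs agree within the assumed float16 tolerances.')
except AssertionError as err:
    print(f'Outputs differ beyond the assumed tolerances:\n{err}')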
Thank you for your time and for developing such an efficient attention mechanism! Please let me know if you need further details or additional benchmarks.
Best regards,