Output Discrepancy Between FlashAttention and PyTorch Attention #1359

Closed
pengzhangzhi opened this issue Nov 27, 2024 · 2 comments

Comments

@pengzhangzhi

I recently benchmarked FlashAttention against PyTorch’s scaled_dot_product_attention using a custom script and observed a significant discrepancy between their outputs. Below are the details of the issue:

Benchmark Summary:

•	Setup:
	•	PyTorch Version: 2.5.0+cu121
	•	FlashAttention Version: 2.7.0.post2
	•	Device: NVIDIA A100 GPU
	•	Data Type: torch.float16
•	Observation:
	•	Max Absolute Difference: 5.265625
	•	Mean Absolute Difference: 0.7984185814857483

Steps to Reproduce:

Here is the benchmark code I used for testing:

import torch
import torch.nn.functional as F
import math

def PyTorchAttention(query_layer, key_layer, value_layer, attention_mask):
    """
    Attention implementation using PyTorch's scaled_dot_product_attention.
    """
    B, H, L, D = query_layer.shape
    # Compute attention using PyTorch's built-in function
    context_layer = F.scaled_dot_product_attention(
        query_layer, key_layer, value_layer,
        attn_mask=attention_mask,
        is_causal=False,
        scale=1,
    )
    # Rearrange and reshape the context layer
    context_layer = context_layer.reshape(B, L, H*D)
    return context_layer

def FlashAttention(query_layer, key_layer, value_layer, attention_mask):
    """
    Attention implementation using FlashAttention.
    """
    B, H, L, D = query_layer.shape
    qkv = torch.stack((query_layer, key_layer, value_layer), dim=1).reshape(B, L, 3, H, D)
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn import flash_attn_varlen_qkvpacked_func
    # Unpad the input sequences based on the attention mask
    qkv_unpadded, indices, cu_seqlens, max_seqlen, _ = unpad_input(
        hidden_states=qkv, attention_mask=attention_mask.reshape(B, L)
    )
    # Apply FlashAttention
    fa_out = flash_attn_varlen_qkvpacked_func(
        qkv_unpadded,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        softmax_scale=1.0,
        causal=False,
    )
    # Pad the output back to the original sequence length
    fa_out = pad_input(fa_out, indices, B, L).to(torch.float32)
    # Reshape and rearrange the output to match the expected shape
    fa_out = fa_out.reshape(B, L, H*D)
 
    return fa_out

def compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype=torch.float16):
    """
    Generates random inputs and compares the outputs of the two attention implementations.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generate random inputs
    query_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    key_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)
    value_layer = torch.randn(batch_size, num_heads, seq_length, head_dim, dtype=dtype, device=device, requires_grad=False)

    # Generate a random attention mask
    attention_mask = torch.randint(0, 2, (batch_size, seq_length), dtype=torch.bool, device=device)
    
    attn_mask = attention_mask[:, None, None, :]  # [batch_size, 1, 1, seq_length]

    # Compute outputs from both implementations
    output1 = PyTorchAttention(query_layer, key_layer, value_layer, attn_mask)
    output2 = FlashAttention(query_layer, key_layer, value_layer, attention_mask)

    # Compare outputs
    difference = torch.abs(output1 - output2)
    max_diff = difference.max().item()
    mean_diff = difference.mean().item()

    print(f'Data type: {dtype}')
    print(f'Max absolute difference: {max_diff}')
    print(f'Mean absolute difference: {mean_diff}')

def seed_all(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == '__main__':
    # Set parameters
    batch_size = 2
    num_heads = 4
    seq_length = 128
    head_dim = 64  # Dimension per head
    dtype = torch.float16  # Change to torch.float32 or torch.float64 as needed
    seed_all(42)
    # Compare implementations
    from torch.nn.attention import sdpa_kernel, SDPBackend
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        compare_attention_implementations(batch_size, num_heads, seq_length, head_dim, dtype)

The discrepancy is consistent across multiple runs, even with deterministic settings (e.g., fixed seeds).

Questions and Notes:

1.	Is this discrepancy expected under certain conditions (e.g., data type torch.float16, attention mask handling)?
2.	If not, could this indicate a potential bug or a precision issue in my benchmark implementation?

Thank you for your time and for developing such an efficient attention mechanism! Please let me know if you need further details or additional benchmarks.

Best regards,

@tridao
Member

tridao commented Nov 27, 2024

This reshape is not what you want

qkv = torch.stack((query_layer, key_layer, value_layer), dim=1).reshape(B, L, 3, H, D)

Please see the docstring or the tests (e.g. def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype)) to see what input layout the function expects.
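
Concretely, a minimal sketch of building that layout, reusing query_layer/key_layer/value_layer (shape (B, H, L, D)) from the script above; the q_blhd/k_blhd/v_blhd names are just illustrative:

# unpad_input takes a (B, L, ...) tensor and flash_attn_varlen_qkvpacked_func
# takes the unpadded (total, 3, H, D) result, so the packed tensor should be
# (B, L, 3, H, D) before unpadding. A reshape does not move axes, so transpose
# sequence and heads first, then stack along a new "3" dimension.
q_blhd = query_layer.transpose(1, 2)  # (B, H, L, D) -> (B, L, H, D)
k_blhd = key_layer.transpose(1, 2)
v_blhd = value_layer.transpose(1, 2)
qkv = torch.stack((q_blhd, k_blhd, v_blhd), dim=2)  # (B, L, 3, H, D)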

@pengzhangzhi
Author

Hey Tri,

Thanks for taking the time to respond! It turns out it was a minor oversight on my end: I need to zero out the pad tokens in the PyTorch SDPA output before comparing.
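
For reference, a minimal sketch of that fix, reusing output1, output2, attention_mask, batch_size, and seq_length from the benchmark script above (and assuming the qkv layout is also corrected as suggested):

# pad_input fills padded rows with zeros, while SDPA still produces values at
# padded query positions, so zero out the padded positions in both outputs
# before comparing.
valid = attention_mask.reshape(batch_size, seq_length, 1)  # (B, L, 1) bool mask
diff = torch.abs(output1.float() * valid - output2.float() * valid)
print(f'Max: {diff.max().item()}, Mean: {diff.mean().item()}')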

Thanks for the great work! I re-implemented a protein language model with FlashAttention and got a huge speedup and memory savings.

Cheers,

Fred
