2-Pass Sdpa Inference Kernel #1597
Conversation
Amazing speedup!! LGTM!
        keys += blocks * stride;
        values += blocks * stride;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
That barrier may be unnecessary?
This is actually a great catch. It should be inside the loop, and the same goes for the 1-pass kernel. The reasoning is that it makes sure the whole threadgroup reads each block at the same time, so one simdgroup cannot just run ahead. I had seen it provide a small improvement, but then in one of the edits it probably got restored back outside the loop.
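For illustration, a minimal sketch of the placement described here, with the barrier moved inside the streaming loop (the follow-up below ultimately removes the barrier instead). The keys/values/blocks/stride names follow the diff above; the kernel signature and loop body are stand-ins, not the PR's actual kernel:

    #include <metal_stdlib>
    using namespace metal;

    kernel void sdpa_block_stream_sketch(
        const device float* keys   [[buffer(0)]],
        const device float* values [[buffer(1)]],
        device float* out          [[buffer(2)]],
        constant int& num_blocks   [[buffer(3)]], // hypothetical loop bound
        constant int& blocks       [[buffer(4)]],
        constant int& stride       [[buffer(5)]],
        uint tid [[thread_position_in_threadgroup]]) {
      float acc = 0.0f;
      for (int i = 0; i < num_blocks; i++) {
        // Inside the loop: the whole threadgroup reads each block together,
        // so no simdgroup can run ahead of the others.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        acc += keys[tid] * values[tid]; // stand-in for the real block computation

        // Advance to the next block of the K/V cache, as in the diff.
        keys += blocks * stride;
        values += blocks * stride;
      }
      out[tid] = acc; // stand-in write, assuming one threadgroup per output row
    }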
I take it back... I tested again and I get mixed results, which is probably why I reverted it in some previous commit. I will remove it since, indeed, conceptually it needn't be there.
This PR aims to improve long-context generation performance by increasing parallelization for large numbers of keys/values. There are mild benefits for smaller machines and very significant benefits for Ultra machines.
The main benefit for small machines stems from accessing the keys and values in a more cache-friendly way when there is GQA; for the Ultra machines it stems from launching more threadgroups, which allows using more of the chip.
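To sketch the 2-pass idea itself: pass 1 processes disjoint partitions of the keys/values in independent threadgroups, and pass 2 combines the partial results. Below is a sketch of the combine step in a flash-decoding style, assuming pass 1 wrote each partition's normalized output plus its running logit max and softmax sum; all buffer names and layouts here are assumptions, not the PR's actual interface:

    #include <metal_stdlib>
    using namespace metal;

    kernel void sdpa_2pass_combine_sketch(
        const device float* partial_out [[buffer(0)]], // [n_partitions, head_dim], normalized per partition
        const device float* partial_max [[buffer(1)]], // [n_partitions], running logit max
        const device float* partial_sum [[buffer(2)]], // [n_partitions], softmax sum at that max
        device float* out               [[buffer(3)]], // [head_dim]
        constant uint& n_partitions     [[buffer(4)]],
        constant uint& head_dim         [[buffer(5)]],
        uint d [[thread_position_in_grid]]) {
      if (d >= head_dim) return;

      // Global max across partitions keeps the rescaling numerically stable.
      float m = -INFINITY;
      for (uint p = 0; p < n_partitions; p++) {
        m = max(m, partial_max[p]);
      }

      // softmax(QK^T)V over the full sequence is recovered by rescaling each
      // partition to the global max:
      //   out = sum_p exp(m_p - m) * s_p * o_p / sum_p exp(m_p - m) * s_p
      float denom = 0.0f;
      float acc = 0.0f;
      for (uint p = 0; p < n_partitions; p++) {
        float w = exp(partial_max[p] - m) * partial_sum[p];
        denom += w;
        acc += w * partial_out[p * head_dim + d];
      }
      out[d] = acc / denom;
    }

Because the pass-1 partitions are independent, many more threadgroups can be launched at once, which is what lets the Ultra-class chips use more of their compute.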
Speedup for M2 Max
The following speedup is in total tokens per second, not attention-only speedup. Note that the phi model, the one that does not improve, is also the one without GQA. The 1-pass SDPA on the M2 Max already achieves ~350 to 380 GB/s of reads at sequence length ~2048, close to the chip's 400 GB/s peak memory bandwidth, so there isn't much room left for speedup.
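To make that headroom concrete (a back-of-the-envelope bound; the symbols are mine, not from the PR): single-token decoding attention mostly streams the K/V cache once per token, so

$$t_{\text{attn}} \approx \frac{2\,L\,H_{kv}\,d_{head}\,b}{BW_{\text{achieved}}}, \qquad \text{remaining speedup} \lesssim \frac{BW_{\text{peak}}}{BW_{\text{achieved}}} \approx \frac{400}{350\text{--}380} \approx 1.05\text{--}1.14\times$$

where $L$ is the sequence length, $H_{kv}$ the number of K/V heads, $d_{head}$ the head dimension, $b$ the bytes per element, and 400 GB/s the M2 Max's peak memory bandwidth.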
Speedup for M2 Ultra
Again, the speedup is in total tokens per second, not attention-specific. The M2 Ultra is sped up in all cases, no GQA required. At sequence length 2048 without GQA the kernel peaks at >800 GB/s, at or above the chip's nominal peak memory bandwidth, which also means there is probably little room for improvement there (though there could be for longer sequences).