update custom PA kernel with support for fp8 kv cache dtype #87
Conversation
update custom PA kernel with support for fp8 kv cache dtype; change custom PA partition size to 512 to prefer throughput scenarios at cost of latency
@@ -20,6 +22,20 @@ typedef float16x4 _Half4;
 typedef struct _Half8 {
   _Half4 xy[2];
 } _Half8;
Are the types below defined to unify float16 and bfloat16?
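For context, a minimal sketch of what such a unification could look like: one templated 8-element vector built over either 4-element base type, so the same kernel code works for float16 and bfloat16. The struct definitions of `float16x4`/`bfloat16x4` are placeholders for the ROCm vector types, and `_T8` is a hypothetical name, not necessarily what the PR uses.

```cpp
#include <cstdint>

// Placeholder 4-element vector types; on ROCm these would be the native
// float16x4 / bfloat16x4 hardware vector types.
struct float16x4 { uint16_t v[4]; };
struct bfloat16x4 { uint16_t v[4]; };

// One generic 8-element vector over any 4-element base type, so the same
// kernel code can be instantiated for both float16 and bfloat16.
template <typename Vec4>
struct _T8 {
  Vec4 xy[2];
};

using _Half8  = _T8<float16x4>;   // same layout as the existing _Half8
using _Bf16x8 = _T8<bfloat16x4>;  // bfloat16 counterpart
```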
@@ -15,7 +15,7 @@
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE_V1V2 = 512
-_PARTITION_SIZE_CUSTOM = 256
+_PARTITION_SIZE_CUSTOM = 512
Do we still need to keep this as a separate definition now?
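For reference, a minimal sketch (with assumed names, not the PR's actual code) of how the partition size drives the v2 launcher's grid, and why 512 trades latency for throughput: fewer, larger partitions mean less cross-partition reduction work per sequence, but also fewer blocks over which a single long sequence can be parallelized.

```cpp
// kPartitionSize must stay in sync with _PARTITION_SIZE_CUSTOM on the
// Python side; the names below are illustrative.
constexpr int kPartitionSize = 512;

inline int divide_round_up(int a, int b) { return (a + b - 1) / b; }

// Each sequence is split into ceil(seq_len / partition_size) partitions,
// each handled by one thread block, followed by a reduction across them.
inline int max_num_partitions(int max_seq_len) {
  return divide_round_up(max_seq_len, kPartitionSize);
}
```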
Ready to ship after some changes, thanks to @mawong-amd.
I've updated this on top of the latest main and also enabled the FP8 KV cache with BF16.
@mawong-amd I'll take a quick look.
* update custom PA kernel with support for fp8 kv cache dtype; change custom PA partition size to 512 to prefer throughput scenarios at cost of latency
* Fix lint
* Fix BF16 with FP8 KV cache (scaled conversion incorrectly done in fp16)
---------
Co-authored-by: Matthew Wong <[email protected]>
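The BF16 fix in the squashed commits is worth spelling out. Below is a minimal sketch of the idea with hypothetical helper names (the real kernel uses ROCm fp8 intrinsics): apply the KV-cache scale in fp32 and narrow to bf16 only at the end. Routing the scaled value through fp16 first loses bf16's wider exponent range.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

using bf16_t = uint16_t;  // placeholder for the hardware bfloat16 type

// Simplified E4M3 decode (normal numbers only; real code also handles
// denormals and NaN).
inline float fp8_e4m3_to_float(uint8_t v) {
  int sign = (v >> 7) & 1;
  int exp  = (v >> 3) & 0xF;
  int man  = v & 0x7;
  float f = std::ldexp(1.0f + man / 8.0f, exp - 7);
  return sign ? -f : f;
}

// Truncating fp32 -> bf16 narrowing (rounding omitted for brevity).
inline bf16_t float_to_bf16(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  return static_cast<bf16_t>(bits >> 16);
}

// The fix: scale in full fp32 precision, then narrow once at the end.
// The bug scaled after converting to fp16, clamping values above fp16's
// ~65504 maximum and flushing small ones, even though bf16 can represent
// them.
inline bf16_t dequant_to_bf16(uint8_t fp8_val, float scale) {
  return float_to_bf16(fp8_e4m3_to_float(fp8_val) * scale);
}
```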
Custom PA kernel support added for the fp8 KV cache dtype.
Changed the custom PA partition size to 512 to favor throughput scenarios at the cost of latency.
Unit tested for fp16 and fp8 KV cache, with num_qheads=64, num_kvheads=8 only.
End-to-end tested on the MLPerf offline benchmark.
Unit tests will be updated in a separate PR.
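As a rough illustration of what fp8 KV cache support means inside the kernel (assumed structure, not the PR's actual code), the cache storage type becomes a template parameter and loads are dequantized with a per-tensor scale before use; `fp8_e4m3_to_float` is reused from the sketch above.

```cpp
#include <cstdint>

float fp8_e4m3_to_float(uint8_t v);  // assumed decode primitive (see above)

// Hypothetical dequant-on-load: cache_t is uint8_t when the KV cache is
// stored as fp8, or the model dtype when it is not.
template <typename cache_t, bool kCacheIsFp8>
inline float load_kv(const cache_t* cache, int idx, float kv_scale) {
  if constexpr (kCacheIsFp8) {
    return fp8_e4m3_to_float(cache[idx]) * kv_scale;  // dequantize on load
  } else {
    return static_cast<float>(cache[idx]);
  }
}
```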