Skip to content

Commit 0aecd17

Browse files
jianyuh authored and facebook-github-bot committed
Dedup GQA splitk kernel
Summary: We want to keep the use_tensor_cores = False option for the gqa_attn_splitk function for backward compatibility (GPUs before Hopper, AMD).

Reviewed By: sryap

Differential Revision: D56687037

fbshipit-source-id: 0c98fe6327fd063b62d59aaaacd238cacbfb20c5
1 parent 0da2f0c commit 0aecd17

File tree

2 files changed

+983
-38
lines changed

2 files changed

+983
-38
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp

+6-5
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,15 @@
1212

1313
namespace fbgemm_gpu::gen_ai::attention {
1414

15-
std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_cuda(
15+
std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(
1616
const at::Tensor& XQ,
1717
const at::Tensor& cache_K,
1818
const at::Tensor& cache_V,
1919
const at::Tensor& seq_positions,
2020
const double qk_scale,
2121
const int64_t num_split_ks,
22-
const int64_t num_groups);
23-
22+
const int64_t num_int4_kv_groups,
23+
const bool use_tensor_cores);
2424
} // namespace fbgemm_gpu::gen_ai::attention
2525

2626
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -32,7 +32,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
3232
" Tensor seq_positions, "
3333
" float qk_scale, "
3434
" int num_split_ks, "
35-
" int num_int4_kv_groups=1"
35+
" int num_int4_kv_groups=1, "
36+
" bool use_tensor_cores=True"
3637
") -> (Tensor, Tensor, Tensor)");
3738
}
3839

@@ -41,5 +42,5 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
4142
"gqa_attn_splitk",
4243
torch::dispatch(
4344
c10::DispatchKey::CUDA,
44-
TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk_cuda)));
45+
TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk)));
4546
}

0 commit comments

Comments (0)