-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve mixed dtype GEMM #1972
Improve mixed dtype GEMM #1972
Conversation
Just out of curiosity, how much does it improve latency on which cases? |
@Algy Hi. The new functionality eliminates some instructions in the dequantization phase for the 4bit x 16bit case and int8 x 16bit case. It is expected to have ~3% improvement when dequantization is not negligible, eg., when problem K is very small. You can turn on/off this feature in the int4xbf16 example by changing the There's additional improvements for all the cases from refactoring the dequantization codes. The other improvements are refactoring and improving the robustness. Eg., you can now use the argument |
* Handle MNK Sm90{Row, Col}Reduction problem shapes (NVIDIA#1803) * add is_last_tile * Improve sm90 mixed dtype kernel (NVIDIA#1883) * Add GMMA shape m64n40k16 (NVIDIA#1864) * Add all supported GMMA shapes (NVIDIA#1890) * add maximum support (NVIDIA#1833) * fix typo (NVIDIA#1853) * fix by adding public (NVIDIA#1753) * added mapping for bf16 to torch::kBFloat16 (NVIDIA#1843) Co-authored-by: Haicheng Wu <[email protected]> * Fix README (NVIDIA#1658) * Fix README * Improve README --------- Co-authored-by: Haicheng Wu <[email protected]> * Adjusting code indentation (NVIDIA#1639) * Include of regular_tile_iterator.h fixed for NVRTC (NVIDIA#1765) * Include of regular_tile_iterator.h fixed for NVRTC * More include fixed for NVRTC * Update gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu with include "cutlass/gemm/device/gemm_universal.h" (NVIDIA#1569) fix compile with `cmake .. -DCUTLASS_ENABLE_TESTS=ON -DCUTLASS_TEST_LEVEL=2` * remove redundant hardcoded packing configs in mixed dtype gemm (NVIDIA#1894) Co-authored-by: Siyuan Fu <[email protected]> * fix wrong A/BLayout in MMA_Traits for binary mma and append other MMA_Traits support (NVIDIA#1856) * fix wrong A/BLayout in MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> and append support for m8n8k128, m16n8k128 mma.and.popc in MMA_Traits instantiation * add "print" template for subbyte_reference<T> * Add a print for the uint{x}b_t type. (NVIDIA#1871) * Refactor some GroupedGEMM logic (NVIDIA#1899) * feat: support kFactor 8 used in mma tensor op tile iterator (NVIDIA#1512) * Update publications (NVIDIA#1912) * remove restriction of stride == kernel in nhwc_pooling (NVIDIA#1896) * fix undefined in device code error (NVIDIA#1880) * Fix the racing condition of mixed-input gemm when writing the registers (NVIDIA#1931) * move two warpgroup_wait * merge main --------- Co-authored-by: Siyuan Fu <[email protected]> * Fix `cutlass` python library with cuda `12.6.2.post1` (NVIDIA#1942) * Fix `cutlass` python library with cuda `12.6.2.post1` Previously we had this error: ``` File "/storage/home/cutlass/python/cutlass/backend/operation.py", line 39, in <listcomp> _version_splits = [int(x) for x in __version__.split("rc")[0].split(".")] ^^^^^^ ValueError: invalid literal for int() with base 10: 'post1' ``` * Update sm90_utils.py * Update generator.py * Update python/cutlass_library/generator.py Co-authored-by: Jack Kosaian <[email protected]> * Update python/cutlass_library/sm90_utils.py Co-authored-by: Jack Kosaian <[email protected]> --------- Co-authored-by: Jack Kosaian <[email protected]> * add {uint4, uint2, int2} => {fp16, bf16} conversion (NVIDIA#1966) * Improve mixed dtype GEMM (NVIDIA#1972) * update * fix a typo * fix a typo that fails the compiling when ElementScale is not the same as MmaType (NVIDIA#1977) * Fix CuTe README Typo (NVIDIA#1951) * Fix Typo (NVIDIA#1962) * 3.6.0 update (NVIDIA#2005) * 3.6.0 update * doc and swap stuff --------- Co-authored-by: yuzhai <[email protected]> Co-authored-by: Haicheng Wu <[email protected]> * Update CHANGELOG.md * Update 0x_gemm_tutorial.md (NVIDIA#1982) Shouldn't this be BLK_M, BLK_**K**, k * fix bug: arch/mma_sm60.h Mma<2,2,1> calculate wrong (NVIDIA#1989) * fix mem fence (NVIDIA#2030) Co-authored-by: yuzhai <[email protected]> * Add half->int8 saturate conversion to promise valid range (NVIDIA#1983) * Add half->int8 saturate conversion to promise valid range * add gpu only macro --------- Co-authored-by: Haicheng Wu <[email protected]> * Add vector-types back to platform.h (NVIDIA#2026) * Fix typo in library_defaults.py (NVIDIA#2024) * Fix Typos (NVIDIA#2021) * Fix Typo * Fix Typo * Add Line Break (NVIDIA#2020) * Blockwise Scaling for FP8 (NVIDIA#1932) * F8 Blockwise Scaling * two more NumProducerThreadEvents --------- Co-authored-by: Haicheng Wu <[email protected]> * fix assertion in integer_subbytes.h (NVIDIA#1961) * CUTLASS 3.7 (NVIDIA#2045) * CUTLASS 3.7 * clean up changelog --------- Co-authored-by: yuzhai <[email protected]> Co-authored-by: Haicheng Wu <[email protected]> * update 3.7 docs (NVIDIA#2051) * update docs * update docs * update docs * update docs * update docs --------- Co-authored-by: yuzhai <[email protected]> * CUTLASS 3.8 Release (NVIDIA#2059) * CUTLASS 3.8 Release * update * Update README.md * Revert "Update README.md" This reverts commit b353e36. * update * update --------- Co-authored-by: Haicheng Wu <[email protected]> Co-authored-by: Haicheng Wu <[email protected]> * fix cuda 12.6 issues (NVIDIA#2066) * fix a readme broken link (NVIDIA#2069) * Update README.md * Groupwise scaling along M for FP8 gemm (NVIDIA#2037) * FP8 groupwise scaling along M * small updates --------- Co-authored-by: zl <[email protected]> Co-authored-by: Haicheng Wu <[email protected]> * bugfix generic-k code in top-k with softmax (NVIDIA#1993) * bugfix generic-k code in top-k with softmax * Update include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp Co-authored-by: Ali Hassani <[email protected]> * Update examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu Co-authored-by: Ali Hassani <[email protected]> --------- Co-authored-by: Ali Hassani <[email protected]> * [EVT] Add support for Row/Col broadcast PtrArray (NVIDIA#2033) * Add group support to EVT row/col broadcast. * small modifications --------- Co-authored-by: Haicheng Wu <[email protected]> * v3.8.0 update (NVIDIA#2082) * 3.8 update * fix Markus' name --------- Co-authored-by: yuzhai <[email protected]> * [WA] Fix compiling errors --------- Co-authored-by: Saagar Jha <[email protected]> Co-authored-by: Haicheng Wu <[email protected]> Co-authored-by: Sergey Klevtsov <[email protected]> Co-authored-by: Tri Dao <[email protected]> Co-authored-by: Xinyu Yang <[email protected]> Co-authored-by: sijialou <[email protected]> Co-authored-by: Bogumil Sapinski Mobica <[email protected]> Co-authored-by: Haicheng Wu <[email protected]> Co-authored-by: Lei Mao <[email protected]> Co-authored-by: 103yiran <[email protected]> Co-authored-by: MaxAkaAltmer <[email protected]> Co-authored-by: 侯奇 <[email protected]> Co-authored-by: Lain <[email protected]> Co-authored-by: Siyuan Fu <[email protected]> Co-authored-by: Caleb_Du <[email protected]> Co-authored-by: LiYu Lu <[email protected]> Co-authored-by: azhurkevich <[email protected]> Co-authored-by: chenwei <[email protected]> Co-authored-by: Wenlei Bao <[email protected]> Co-authored-by: LiuQiang <[email protected]> Co-authored-by: dan_the_3rd <[email protected]> Co-authored-by: Jack Kosaian <[email protected]> Co-authored-by: Yujia Zhai <[email protected]> Co-authored-by: yuzhai <[email protected]> Co-authored-by: Andrew O'Neill <[email protected]> Co-authored-by: Dongxu.Wang <[email protected]> Co-authored-by: ZZK <[email protected]> Co-authored-by: Driss Guessous <[email protected]> Co-authored-by: ZincCat <[email protected]> Co-authored-by: Manish Gupta <[email protected]> Co-authored-by: bobliao <[email protected]> Co-authored-by: mihir-awatramani <[email protected]> Co-authored-by: Liang <[email protected]> Co-authored-by: zl <[email protected]> Co-authored-by: Tadej Ciglarič <[email protected]> Co-authored-by: Ali Hassani <[email protected]> Co-authored-by: Josh Fromm <[email protected]>
* update * fix a typo
MixedInput
in the mixed dtype GEMM's collective schedules. See Eliminate MixedInput kernel schedule tags. #1956examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu
andinclude/cutlass/detail/collective/mixed_input_utils.hpp
for more details