Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix EVT for cutlass::gemm::kernel::DefaultGemmWithVisitor's behavior when constructing GemmUniversalAdapter #1753

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

Xinyu302
Copy link
Contributor

@Xinyu302 Xinyu302 commented Aug 28, 2024

The phenomenon I encountered while using the epilogue visitor tree in C++ is quite strange...
In examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu:

using EVTKernelStreamK =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
    ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
    ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
    ElementC, LayoutC, AlignmentC,
    ElementAccumulator,
    ElementCompute,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EVTD,
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
    NumStages,
    cutlass::arch::OpMultiplyAdd,
    EVTEpilogueStages
>::GemmKernel;

using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;

I didn't want to use cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, so I replaced it with cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>.
When I construct DeviceGemmStreamK::Arguments, it errors like this:

type "cutlass::gemm::kernel::GemmUniversal<Mma_, Epilogue_, ThreadblockSwizzle_, void, std::enable_if_t<<expression>, void>>::WarpShape ...
/workspace/cutlass/python/cutlass_library/../../include/cutlass/gemm/kernel/gemm_universal.h(94): here is inaccessible

I think it is because when if we don't use ThreadblockSwizzleStreamK, DefaultGemmWithVisitor will choose GemmWithEpilogueVisitor as the type of GemmKernel. So EVTKernelStreamK will be a subclass of GemmWithEpilogueVisitor in include/cutlass/gemm/kernel/gemm_universal_with_visitor.h.

The GemmWithEpilogueVisitor is a subclass of GemmUniversal, GemmUniversal has all the types we need to help us construct GemmUniversalAdapter, but when GemmWithEpilogueVisitor inherits from GemmUniversal, it doesn't use public inheritance, which means that the required types like Mma, WarpShape are not accessible in GemmWithEpilogueVisitor.

So, I changed the inheritance relationship to public to ensure the consistency of DefaultGemmWithVisitor behavior. Besides this, I also need to modify the GemmWithEpilogueVisitor class.

I would like to hear your suggestions on this issue.

@Xinyu302 Xinyu302 changed the title [Fix] Fix cutlass::gemm::kernel::DefaultGemmWithVisitor's behavior when constructing GemmUniversalAdapter Fix cutlass::gemm::kernel::DefaultGemmWithVisitor's behavior when constructing GemmUniversalAdapter Aug 28, 2024
@Xinyu302
Copy link
Contributor Author

@apuaaChen Could you please help to take a look at this problem?

@gau-nernst
Copy link

Thank you for the fix. I tried this locally and it works wonder! Much faster than ThreadblockSwizzleStreamK. Hope the maintainers can merge this.

@Xinyu302 Xinyu302 changed the title Fix cutlass::gemm::kernel::DefaultGemmWithVisitor's behavior when constructing GemmUniversalAdapter Fix EVT for cutlass::gemm::kernel::DefaultGemmWithVisitor's behavior when constructing GemmUniversalAdapter Sep 21, 2024
@Xinyu302
Copy link
Contributor Author

@hwu36 Could you please take a look at this PR?

Copy link

This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@hwu36 hwu36 merged commit b0c09ed into NVIDIA:main Oct 23, 2024
sijialouintel added a commit to sijialouintel/cutlass that referenced this pull request Feb 12, 2025
* Handle MNK Sm90{Row, Col}Reduction problem shapes (NVIDIA#1803)

* add is_last_tile

* Improve sm90 mixed dtype kernel (NVIDIA#1883)

* Add GMMA shape m64n40k16 (NVIDIA#1864)

* Add all supported GMMA shapes (NVIDIA#1890)

* add maximum support (NVIDIA#1833)

* fix typo (NVIDIA#1853)

* fix by adding public (NVIDIA#1753)

* added mapping for bf16 to torch::kBFloat16 (NVIDIA#1843)

Co-authored-by: Haicheng Wu <[email protected]>

* Fix README (NVIDIA#1658)

* Fix README

* Improve README

---------

Co-authored-by: Haicheng Wu <[email protected]>

* Adjusting code indentation (NVIDIA#1639)

* Include of regular_tile_iterator.h fixed for NVRTC (NVIDIA#1765)

* Include of regular_tile_iterator.h fixed for NVRTC

* More include fixed for NVRTC

* Update gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu with include "cutlass/gemm/device/gemm_universal.h" (NVIDIA#1569)

fix compile with `cmake .. -DCUTLASS_ENABLE_TESTS=ON -DCUTLASS_TEST_LEVEL=2`

* remove redundant hardcoded packing configs in mixed dtype gemm (NVIDIA#1894)

Co-authored-by: Siyuan Fu <[email protected]>

* fix wrong A/BLayout in MMA_Traits for binary mma and append other MMA_Traits support  (NVIDIA#1856)

* fix wrong A/BLayout in  MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> and append support for  m8n8k128, m16n8k128  mma.and.popc in MMA_Traits instantiation

* add "print" template for  subbyte_reference<T>

* Add a print for the uint{x}b_t type. (NVIDIA#1871)

* Refactor some GroupedGEMM logic (NVIDIA#1899)

* feat: support kFactor 8 used in mma tensor op tile iterator (NVIDIA#1512)

* Update publications (NVIDIA#1912)

* remove restriction of stride == kernel in nhwc_pooling (NVIDIA#1896)

* fix undefined in device code error (NVIDIA#1880)

* Fix the racing condition of mixed-input gemm when writing the registers (NVIDIA#1931)

* move two warpgroup_wait

* merge main

---------

Co-authored-by: Siyuan Fu <[email protected]>

* Fix `cutlass` python library with cuda `12.6.2.post1` (NVIDIA#1942)

* Fix `cutlass` python library with cuda `12.6.2.post1`

Previously we had this error:
```
  File "/storage/home/cutlass/python/cutlass/backend/operation.py", line 39, in <listcomp>
    _version_splits = [int(x) for x in __version__.split("rc")[0].split(".")]
                       ^^^^^^
ValueError: invalid literal for int() with base 10: 'post1'
```

* Update sm90_utils.py

* Update generator.py

* Update python/cutlass_library/generator.py

Co-authored-by: Jack Kosaian <[email protected]>

* Update python/cutlass_library/sm90_utils.py

Co-authored-by: Jack Kosaian <[email protected]>

---------

Co-authored-by: Jack Kosaian <[email protected]>

* add {uint4, uint2, int2} => {fp16, bf16} conversion (NVIDIA#1966)

* Improve mixed dtype GEMM (NVIDIA#1972)

* update

* fix a typo

* fix a typo that fails the compiling when ElementScale is not the same as MmaType (NVIDIA#1977)

* Fix CuTe README Typo (NVIDIA#1951)

* Fix Typo (NVIDIA#1962)

* 3.6.0 update (NVIDIA#2005)

* 3.6.0 update

* doc and swap stuff

---------

Co-authored-by: yuzhai <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* Update CHANGELOG.md

* Update 0x_gemm_tutorial.md (NVIDIA#1982)

Shouldn't this be BLK_M, BLK_**K**, k

* fix bug: arch/mma_sm60.h Mma<2,2,1> calculate wrong (NVIDIA#1989)

* fix mem fence (NVIDIA#2030)

Co-authored-by: yuzhai <[email protected]>

* Add half->int8 saturate conversion to promise valid range (NVIDIA#1983)

* Add half->int8 saturate conversion to promise valid range

* add gpu only macro

---------

Co-authored-by: Haicheng Wu <[email protected]>

* Add vector-types back to platform.h (NVIDIA#2026)

* Fix typo in library_defaults.py (NVIDIA#2024)

* Fix Typos (NVIDIA#2021)

* Fix Typo

* Fix Typo

* Add Line Break (NVIDIA#2020)

* Blockwise Scaling for FP8 (NVIDIA#1932)

* F8 Blockwise Scaling

* two more NumProducerThreadEvents

---------

Co-authored-by: Haicheng Wu <[email protected]>

* fix assertion in integer_subbytes.h (NVIDIA#1961)

* CUTLASS 3.7 (NVIDIA#2045)

* CUTLASS 3.7

* clean up changelog

---------

Co-authored-by: yuzhai <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* update 3.7 docs (NVIDIA#2051)

* update docs

* update docs

* update docs

* update docs

* update docs

---------

Co-authored-by: yuzhai <[email protected]>

* CUTLASS 3.8 Release (NVIDIA#2059)

* CUTLASS 3.8 Release

* update

* Update README.md

* Revert "Update README.md"

This reverts commit b353e36.

* update

* update

---------

Co-authored-by: Haicheng Wu <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* fix cuda 12.6 issues (NVIDIA#2066)

* fix a readme broken link (NVIDIA#2069)

* Update README.md

* Groupwise scaling along M for FP8 gemm (NVIDIA#2037)

* FP8 groupwise scaling along M

* small updates

---------

Co-authored-by: zl <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* bugfix generic-k code in top-k with softmax (NVIDIA#1993)

* bugfix generic-k code in top-k with softmax

* Update include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp

Co-authored-by: Ali Hassani <[email protected]>

* Update examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu

Co-authored-by: Ali Hassani <[email protected]>

---------

Co-authored-by: Ali Hassani <[email protected]>

* [EVT] Add support for Row/Col broadcast PtrArray (NVIDIA#2033)

* Add group support to EVT row/col broadcast.

* small modifications

---------

Co-authored-by: Haicheng Wu <[email protected]>

* v3.8.0 update (NVIDIA#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <[email protected]>

* [WA] Fix compiling errors

---------

Co-authored-by: Saagar Jha <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
Co-authored-by: Sergey Klevtsov <[email protected]>
Co-authored-by: Tri Dao <[email protected]>
Co-authored-by: Xinyu Yang <[email protected]>
Co-authored-by: sijialou <[email protected]>
Co-authored-by: Bogumil Sapinski Mobica <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
Co-authored-by: Lei Mao <[email protected]>
Co-authored-by: 103yiran <[email protected]>
Co-authored-by: MaxAkaAltmer <[email protected]>
Co-authored-by: 侯奇 <[email protected]>
Co-authored-by: Lain <[email protected]>
Co-authored-by: Siyuan Fu <[email protected]>
Co-authored-by: Caleb_Du <[email protected]>
Co-authored-by: LiYu Lu <[email protected]>
Co-authored-by: azhurkevich <[email protected]>
Co-authored-by: chenwei <[email protected]>
Co-authored-by: Wenlei Bao <[email protected]>
Co-authored-by: LiuQiang <[email protected]>
Co-authored-by: dan_the_3rd <[email protected]>
Co-authored-by: Jack Kosaian <[email protected]>
Co-authored-by: Yujia Zhai <[email protected]>
Co-authored-by: yuzhai <[email protected]>
Co-authored-by: Andrew O'Neill <[email protected]>
Co-authored-by: Dongxu.Wang <[email protected]>
Co-authored-by: ZZK <[email protected]>
Co-authored-by: Driss Guessous <[email protected]>
Co-authored-by: ZincCat <[email protected]>
Co-authored-by: Manish Gupta <[email protected]>
Co-authored-by: bobliao <[email protected]>
Co-authored-by: mihir-awatramani <[email protected]>
Co-authored-by: Liang <[email protected]>
Co-authored-by: zl <[email protected]>
Co-authored-by: Tadej Ciglarič <[email protected]>
Co-authored-by: Ali Hassani <[email protected]>
Co-authored-by: Josh Fromm <[email protected]>
hgl71964 pushed a commit to hgl71964/cutlass that referenced this pull request Feb 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants