From b12b66f11e36cc6747ce469ed8595a80f730f744 Mon Sep 17 00:00:00 2001 From: yuzhai Date: Fri, 20 Dec 2024 09:59:25 -0800 Subject: [PATCH 1/2] 3.6.0 update --- CHANGELOG.md | 51 +- CMakeLists.txt | 168 +- CUDA.cmake | 93 +- README.md | 1 + .../13_two_tensor_op_fusion/CMakeLists.txt | 1 - .../gather_scatter_fusion.cu | 10 +- examples/39_gemm_permute/layouts.h | 4 - .../debug_utils.h | 4 +- .../epilogue/epilogue_pipelined.h | 7 - .../epilogue/epilogue_rescale_output.h | 8 - .../fmha_grouped.h | 18 +- .../gemm_kernel_utils.h | 5 +- .../kernel_backward.h | 13 +- .../kernel_forward.h | 13 +- .../threadblock/fused_bias_act_epilogue.h | 7 - .../ir_gen/gen_sample.py | 2 +- .../ir_gen/gen_turing_and_volta.py | 2 +- examples/45_dual_gemm/test_run.h | 2 +- .../45_dual_gemm/threadblock/dual_epilogue.h | 4 - .../52_hopper_gather_scatter_fusion.cu | 15 +- .../55_hopper_int4_bf16_gemm.cu | 3 +- .../55_hopper_int4_fp8_gemm.cu | 2 +- examples/55_hopper_mixed_dtype_gemm/README.md | 2 +- .../packed_scale.hpp | 2 +- .../reorder_utils.hpp | 2 +- .../57_hopper_grouped_gemm.cu | 2 +- .../collective/builder.hpp | 31 +- ..._gmma_ss_warpspecialized_with_prefetch.hpp | 31 +- .../64_ada_fp8_gemm_grouped/CMakeLists.txt | 35 + .../ada_fp8_gemm_grouped.cu | 1208 +++++++++++++++ examples/CMakeLists.txt | 2 + examples/cute/tutorial/tiled_copy.cu | 57 +- include/cute/algorithm/cooperative_copy.hpp | 56 +- include/cute/algorithm/cooperative_gemm.hpp | 527 ++++--- include/cute/algorithm/copy.hpp | 453 ++++-- include/cute/arch/copy.hpp | 26 +- include/cute/arch/mma_sm80.hpp | 50 +- include/cute/atom/copy_atom.hpp | 10 +- include/cute/atom/copy_traits.hpp | 33 +- include/cute/atom/copy_traits_sm80.hpp | 65 +- include/cute/atom/copy_traits_sm90_tma.hpp | 183 ++- include/cute/atom/mma_atom.hpp | 49 +- include/cute/atom/mma_traits_sm80.hpp | 207 ++- include/cute/atom/mma_traits_sm90_gmma.hpp | 1 - .../cute/atom/mma_traits_sm90_gmma_sparse.hpp | 2 +- include/cute/config.hpp | 2 +- include/cute/container/array_subbyte.hpp | 39 +- include/cute/layout.hpp | 2 +- include/cute/layout_composed.hpp | 11 +- include/cute/numeric/integral_ratio.hpp | 43 +- include/cute/numeric/numeric_types.hpp | 1 + include/cute/pointer.hpp | 8 + include/cute/pointer_base.hpp | 22 +- include/cute/pointer_swizzle.hpp | 8 + include/cute/swizzle_layout.hpp | 4 +- include/cute/tensor_impl.hpp | 13 + include/cute/util/debug.hpp | 4 +- include/cute/util/type_traits.hpp | 6 + include/cutlass/arch/barrier.h | 93 ++ include/cutlass/arch/config.h | 4 + include/cutlass/arch/memory_sm75.h | 19 +- include/cutlass/arch/mma_sm70.h | 4 - include/cutlass/arch/mma_sm75.h | 4 - include/cutlass/arch/mma_sm80.h | 4 - include/cutlass/arch/mma_sm89.h | 4 - include/cutlass/arch/mma_sm90.h | 4 - include/cutlass/arch/mma_sparse_sm80.h | 4 - include/cutlass/arch/mma_sparse_sm89.h | 4 - include/cutlass/arch/simd.h | 4 +- include/cutlass/arch/synclog.hpp | 2 +- include/cutlass/arch/wmma_sm70.h | 4 - include/cutlass/arch/wmma_sm72.h | 4 - include/cutlass/arch/wmma_sm75.h | 4 - include/cutlass/array.h | 19 +- include/cutlass/array_subbyte.h | 2 + include/cutlass/blas3.h | 2 +- ..._implicit_gemm_gmma_ss_warpspecialized.hpp | 100 +- include/cutlass/conv/convnd_problem_shape.hpp | 40 + .../cutlass/conv/kernel/direct_convolution.h | 3 +- include/cutlass/coord.h | 2 +- include/cutlass/cuda_host_adapter.hpp | 9 +- include/cutlass/detail/collective.hpp | 1 + .../detail/collective/mixed_input_utils.hpp | 1 - include/cutlass/detail/helper_macros.hpp | 10 +- .../mainloop_fusion_helper_scale_factor.hpp | 75 + include/cutlass/device_kernel.h | 5 +- .../collective/builders/sm90_builder.inl | 3 +- .../collective/default_epilogue_array.hpp | 38 +- .../cutlass/epilogue/collective/detail.hpp | 15 +- ...m90_epilogue_array_tma_warpspecialized.hpp | 45 +- .../sm90_epilogue_tma_warpspecialized.hpp | 28 +- include/cutlass/epilogue/dispatch_policy.hpp | 1 + .../cutlass/epilogue/fusion/operations.hpp | 144 ++ .../sm90_callbacks_tma_warpspecialized.hpp | 1081 +++++++++++-- ...90_visitor_compute_tma_warpspecialized.hpp | 12 +- .../sm90_visitor_load_tma_warpspecialized.hpp | 174 ++- ...sm90_visitor_store_tma_warpspecialized.hpp | 52 +- .../sm90_visitor_tma_warpspecialized.hpp | 2 - include/cutlass/epilogue/thread/activation.h | 72 +- .../epilogue/thread/linear_combination.h | 6 +- .../cutlass/epilogue/threadblock/epilogue.h | 14 +- .../epilogue/threadblock/epilogue_base.h | 4 - .../threadblock/epilogue_gemm_k_reduction.h | 4 - .../threadblock/epilogue_smem_accumulator.h | 4 - .../epilogue_streamk_with_broadcast.h | 4 +- .../threadblock/epilogue_with_absmax.h | 4 +- .../threadblock/epilogue_with_broadcast.h | 4 +- .../threadblock/epilogue_with_reduction.h | 4 - .../epilogue_with_visitor_callbacks.h | 22 + .../threadblock/fusion/visitor_load.hpp | 24 +- .../threadblock/fusion/visitor_store.hpp | 5 +- .../threadblock/output_tile_thread_map.h | 4 +- .../warp/tile_iterator_tensor_op_mixed.h | 14 +- include/cutlass/float8.h | 4 +- include/cutlass/floating_point_nvrtc.h | 6 + include/cutlass/functional.h | 2 +- .../collective/builders/sm90_gmma_builder.inl | 31 +- .../collective/collective_builder_decl.hpp | 12 + .../gemm/collective/collective_mma.hpp | 3 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 1370 +++++++++++++++++ ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 2 +- .../gemm/device/gemm_universal_adapter.h | 77 +- include/cutlass/gemm/dispatch_policy.hpp | 5 +- .../gemm/group_array_problem_shape.hpp | 2 +- .../default_gemm_grouped_per_group_scale.h | 384 +++++ include/cutlass/gemm/kernel/ell_gemm.h | 6 +- .../kernel/gemm_grouped_per_group_scale.h | 261 ++++ .../gemm_universal_with_visitor_streamk.h | 4 +- .../gemm/kernel/grouped_problem_visitor.h | 4 +- .../gemm/kernel/params_universal_base.h | 2 +- include/cutlass/gemm/kernel/rank_2k_grouped.h | 16 +- ..._array_tma_warpspecialized_cooperative.hpp | 58 +- ...emm_array_tma_warpspecialized_pingpong.hpp | 57 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 43 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 67 +- .../gemm/kernel/sm90_tile_scheduler.hpp | 2 +- .../gemm/kernel/sm90_tile_scheduler_group.hpp | 10 +- .../kernel/sm90_tile_scheduler_stream_k.hpp | 578 ++++--- .../cutlass/gemm/kernel/tile_scheduler.hpp | 12 +- .../gemm/kernel/tile_scheduler_params.h | 721 ++++++--- include/cutlass/gemm/thread/mma_sm50.h | 46 +- .../gemm/threadblock/ell_mma_multistage.h | 16 +- include/cutlass/integer_subbyte.h | 8 +- include/cutlass/kernel_launch.h | 1 + include/cutlass/layout/permute.h | 5 +- include/cutlass/layout/tensor.h | 6 +- .../layout/tensor_op_multiplicand_sm70.h | 1 + include/cutlass/numeric_conversion.h | 300 ++-- include/cutlass/numeric_size.h | 9 + include/cutlass/numeric_types.h | 2 - include/cutlass/pipeline/sm90_pipeline.hpp | 302 +++- include/cutlass/platform/platform.h | 29 +- include/cutlass/predicate_vector.h | 7 +- include/cutlass/real.h | 2 + .../reduction/thread/reduction_operators.h | 4 +- include/cutlass/tensor_view_planar_complex.h | 1 + include/cutlass/tfloat32.h | 1 + .../device/transform_universal_adapter.hpp | 2 +- .../kernel/sm90_sparse_gemm_compressor.hpp | 2 + media/docs/cute/04_algorithms.md | 2 +- media/docs/quickstart.md | 1 - python/cutlass/__init__.py | 13 + python/cutlass/backend/c_types.py | 2 +- python/cutlass/backend/compiler.py | 4 +- python/cutlass/backend/epilogue.py | 1 + python/cutlass/epilogue/evt_ops.py | 5 +- python/cutlass/library_defaults.py | 16 +- python/cutlass_library/gemm_operation.py | 1 + python/cutlass_library/generator.py | 8 +- python/cutlass_library/library.py | 12 + python/cutlass_library/sm90_utils.py | 26 +- .../python/cutlass/evt/evt_compute_sm80_90.py | 1 - test/self_contained_includes/CMakeLists.txt | 109 ++ test/unit/CMakeLists.txt | 10 +- test/unit/common/filter_architecture.cpp | 1 - .../conv/device_3x/conv_problem_sizes.hpp | 31 + test/unit/conv/device_3x/testbed_conv.hpp | 12 +- test/unit/core/float8.cu | 1 + test/unit/cute/ampere/CMakeLists.txt | 2 +- test/unit/cute/ampere/cooperative_copy.cu | 15 +- test/unit/cute/ampere/cooperative_gemm.cu | 558 ++++--- .../cute/ampere/{cp_async.cu => cp_sync.cu} | 4 +- test/unit/cute/cooperative_gemm_common.hpp | 852 ++++++---- test/unit/cute/core/inverse_left.cpp | 2 +- test/unit/cute/hopper/cooperative_gemm.cu | 93 +- test/unit/cute/hopper/tma_load_testbed.hpp | 2 +- .../cute/hopper/tma_mcast_load_testbed.hpp | 2 +- test/unit/cute/turing/cooperative_gemm.cu | 16 +- test/unit/cute/volta/cooperative_gemm.cu | 365 ++--- test/unit/cute/volta/vectorization_auto.cu | 1 - test/unit/gemm/device/CMakeLists.txt | 207 +-- test/unit/gemm/device/gemm_testbed_3x.hpp | 37 +- .../gemm/device/gemm_testbed_3x_ptr_array.hpp | 39 +- .../device/sm90_gett_f16_f16_f16_tensor_op.cu | 184 +++ test/unit/transform/CMakeLists.txt | 2 +- .../sm90_sparse_gemm_compressor_legacy.hpp | 5 +- .../device/testbed_sparse_gemm_compressor.hpp | 1 + test/unit/transform/kernel/CMakeLists.txt | 2 +- test/unit/util/rms_norm.cu | 2 +- .../library/include/cutlass/library/handle.h | 2 + .../library/include/cutlass/library/library.h | 7 +- tools/library/src/conv2d_operation.h | 14 +- tools/library/src/conv3d_operation.h | 7 +- tools/library/src/conv_operation_3x.hpp | 158 +- tools/library/src/gemm_operation.h | 42 +- tools/library/src/gemm_operation_3x.hpp | 6 +- tools/library/src/handle.cu | 31 +- tools/library/src/library_internal.h | 8 + tools/library/src/rank_2k_operation.h | 7 +- tools/library/src/rank_k_operation.h | 7 +- .../src/reduction/reduction_operation.h | 7 +- .../src/reference/conv_reference_operation.h | 7 +- .../src/reference/gemm_reference_operation.h | 7 +- tools/library/src/reference/gemm_s8_s8_s32.cu | 2 +- .../library/src/sparse_gemm_operation_3x.hpp | 51 +- tools/library/src/symm_operation.h | 7 +- tools/library/src/trmm_operation.h | 7 +- tools/profiler/CMakeLists.txt | 2 +- .../profiler/conv2d_operation_profiler.h | 2 +- .../profiler/conv3d_operation_profiler.h | 2 +- .../include/cutlass/profiler/cublas_helpers.h | 2 +- .../cutlass/profiler/enumerated_types.h | 4 +- .../profiler/gemm_operation_profiler.h | 9 +- .../include/cutlass/profiler/gpu_timer.h | 17 +- .../cutlass/profiler/operation_profiler.h | 2 +- .../cutlass/profiler/performance_result.h | 3 + .../profiler/src/conv2d_operation_profiler.cu | 29 +- .../profiler/src/conv3d_operation_profiler.cu | 6 +- tools/profiler/src/cublas_helpers.cu | 4 +- tools/profiler/src/device_allocation.cu | 56 +- tools/profiler/src/gemm_operation_profiler.cu | 925 ++++++----- tools/profiler/src/gpu_timer.cpp | 31 +- tools/profiler/src/operation_profiler.cu | 4 +- tools/profiler/src/options.cu | 25 +- tools/profiler/src/performance_report.cpp | 20 +- .../src/rank_2k_operation_profiler.cu | 2 +- .../profiler/src/rank_k_operation_profiler.cu | 2 +- .../src/sparse_gemm_operation_profiler.cu | 2 +- tools/profiler/src/symm_operation_profiler.cu | 2 +- tools/profiler/src/trmm_operation_profiler.cu | 2 +- tools/util/include/cutlass/util/device_dump.h | 2 +- .../include/cutlass/util/device_groupnorm.h | 2 +- .../include/cutlass/util/device_layernorm.h | 2 +- .../cutlass/util/device_nhwc_pooling.h | 2 +- .../include/cutlass/util/device_rmsnorm.h | 6 +- .../util/include/cutlass/util/device_utils.h | 2 +- .../util/include/cutlass/util/distribution.h | 3 + .../util/reference/device/tensor_fill.h | 26 +- .../cutlass/util/reference/host/gett.hpp | 1 - .../util/reference/host/rank_2k_complex.h | 2 +- .../util/reference/host/rank_k_complex.h | 2 +- .../util/reference/host/symm_complex.h | 2 +- .../cutlass/util/reference/host/tensor_fill.h | 16 +- tools/util/include/cutlass/util/type_traits.h | 2 +- 254 files changed, 10774 insertions(+), 3801 deletions(-) create mode 100644 examples/64_ada_fp8_gemm_grouped/CMakeLists.txt create mode 100644 examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu create mode 100644 include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp create mode 100644 include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp create mode 100644 include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h create mode 100644 include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h rename test/unit/cute/ampere/{cp_async.cu => cp_sync.cu} (97%) create mode 100644 test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu diff --git a/CHANGELOG.md b/CHANGELOG.md index c98cdb515f..2d0140a586 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,4 @@ # NVIDIA CUTLASS Changelog - ## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03) - [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). @@ -18,7 +17,27 @@ - [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) - Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). - Various improvements and fixed from the community and CUTLASS team. Thanks to everyone who submitted PRs! +- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! + +- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) +- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) +- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and +[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). +- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). +- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence: + + [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411). + + [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452). + + [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452). + + [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456). +- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. +- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). +- Support for residual add (beta != 0) in convolution kernels. +- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. +- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). +- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). +- Better support for MSVC as a host compiler. +- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. +- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. ## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25) @@ -51,7 +70,7 @@ + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. -- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x +- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. @@ -82,7 +101,7 @@ * [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. * [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}. * [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors. -* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors. +* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors. * Profiler support for lower-aligned Hopper GEMMs. * Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion). * Sub-Byte type fixes and improvements. @@ -159,10 +178,10 @@ * [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary. * Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N. * [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added - * [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM. + * [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM. * The restrictions are: 1) input ,output channel and group number should be multiple of (128 / sizeof(input element)). 2) The input filter size should be the same as the template parameter configuration. * [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded. - * The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration. + * The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration. * [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/). * [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115). * Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers). @@ -173,13 +192,13 @@ * CUDA 10.2 ## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23) -* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. +* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. * Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too. * Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance. * [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing. * [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues. * [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue. -* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes: +* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes: * kSingleGroup: output channel per group is multiple of Threadblock tile N. * kMultipleGroup: Threadblock tile N is multiple of output channel per group. * [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number. @@ -235,7 +254,7 @@ * [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu) * **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu * [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu) - * [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu) + * [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu) * [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h) * **Grouped GEMM:** similar to batched GEMM with distinct problem size per group * [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM @@ -274,7 +293,7 @@ * [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) * [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) * 64b tensor strides and leading dimensions support for GEMMs - * Affine rank=2 matrix layouts + * Affine rank=2 matrix layouts * Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h) * Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM. * [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation @@ -289,7 +308,7 @@ * Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations * Performance improvement for FP16 tensor core kernels * Bug fixes - * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. + * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. * Updated minimum CUDA Toolkit requirement to 10.2 * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended * Corrections and bug fixes reported by the CUTLASS community @@ -308,7 +327,7 @@ * [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md) * Corrections and bug fixes reported by the CUTLASS community * Thank you for filing these issues! - + ## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19) * Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs @@ -316,7 +335,7 @@ * Data type: FP32, complex, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32 * Spatial dimensions: 1-D, 2-D, and 3-D * Layout: NHWC, NCxHWx - * Implicit GEMM convolution components: + * Implicit GEMM convolution components: * Global memory iterators supporting Fprop, Dgrad, and Wgrad * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures @@ -332,17 +351,17 @@ * Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code * [Floating-point constants](./include/cutlass/constants.h) * NVIDIA Ampere GPU Architecture examples and documentation: - * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and + * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and * [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue) ## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) - * Fast Tensor Core operations: + * Fast Tensor Core operations: * Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends) * Tensor Float 32, BFloat16, and double-precision data types * Mixed integer data types (int8, int4, bin1) - * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) + * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) * Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required) * Features: * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM diff --git a/CMakeLists.txt b/CMakeLists.txt index e61b66a877..e9c501bc2b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,18 @@ project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_C ################################################################################ +if (CMAKE_CXX_COMPILER_ID MATCHES "GNU") + set(CUTLASS_GNU_HOST_COMPILE ON CACHE BOOL "Using GNU tools for host code compilation") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "[Cc]lang") + set(CUTLASS_CLANG_HOST_COMPILE ON CACHE BOOL "Using Clang tools for host code compilation") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + set(CUTLASS_MSVC_HOST_COMPILE ON CACHE BOOL "Using MSVC tools for host code compilation") +endif() + +################################################################################ + include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -67,11 +79,11 @@ elseif (CUDA_VERSION VERSION_LESS 11.4) message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.") endif() -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3) +if(CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3) message(FATAL_ERROR "GCC version must be at least 7.3!") endif() -if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) +if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") endif() find_package(Doxygen QUIET) @@ -85,13 +97,10 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -if(CUTLASS_NATIVE_CUDA) - set(CMAKE_CUDA_STANDARD 17) - set(CMAKE_CUDA_STANDARD_REQUIRED ON) - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) -else() - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17) -endif() +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE) @@ -146,13 +155,13 @@ endif() ################################################################################ set(CUTLASS_NVCC_ARCHS_SUPPORTED "") -if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87) endif() -if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90) endif() -if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") @@ -246,7 +255,7 @@ set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH") if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS) # If a kernel filter file is specified, we want to generate and then # filter on the entire kernel set, not the default kernel - # (sub)set. The user may have overridden CUTLASS_LIBRRARY_KERNELS, in which + # (sub)set. The user may have overridden CUTLASS_LIBRARY_KERNELS, in which # case the resulting kernel set will be the intersection of the two # options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS. set(CUTLASS_LIBRARY_KERNELS_INIT "*") @@ -375,15 +384,22 @@ endif() # Warnings-as-error exceptions and warning suppressions for Clang builds -if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=implicit-int-conversion ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=implicit-int-conversion" ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pass-failed ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=pass-failed" ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=inconsistent-missing-override ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=inconsistent-missing-override" ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-sign-conversion" ) +if (CUTLASS_CLANG_HOST_COMPILE) + + set(FLAGS_TO_ADD + "-Wno-error=implicit-int-conversion" + "-Wno-error=pass-failed" + "-Wno-error=inconsistent-missing-override" + "-Wno-sign-conversion" + "-Wno-unused-parameter" + ) + + foreach(FLAG ${FLAGS_TO_ADD}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}") + endforeach() + endif() if (NOT MSVC AND CUTLASS_NVCC_KEEP) @@ -396,9 +412,9 @@ endif() if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1) - if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) + if (CUTLASS_GNU_HOST_COMPILE OR CUTLASS_CLANG_HOST_COMPILE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c) - elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC")) + elseif(CUTLASS_MSVC_HOST_COMPILE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2) endif() endif() @@ -423,19 +439,8 @@ if (NOT CMAKE_BUILD_TYPE MATCHES "Release") list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo) endif() -#Report CUDA build flags -if (CUDA_COMPILER MATCHES "[Cc]lang") - if(CUTLASS_CUDA_CLANG_FLAGS) - message(STATUS "Using CLANG flags: ${CUTLASS_CUDA_CLANG_FLAGS}") - endif() -else() - if(CUTLASS_CUDA_NVCC_FLAGS) - message(STATUS "Using NVCC flags: ${CUTLASS_CUDA_NVCC_FLAGS}") - endif() -endif() - -if(CUDA_COMPILER MATCHES "[Cc]lang") - if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) +if (CUTLASS_CLANG_DEVICE_COMPILE) + if (NOT CUTLASS_CLANG_HOST_COMPILE) message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) endif() @@ -451,12 +456,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument) - string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION}) - list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR) - list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}) - # needed for libcublasLt.so in case it's installed in the same location as libcudart.so # dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags) # Otherwise linker uses RUNPATH and that does not propagate to loaded libs. @@ -464,11 +465,26 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") link_libraries(nvidia::cudart) link_libraries(nvidia::cuda_driver) + +endif() + +#Report CUDA build flags +if (CUTLASS_CLANG_DEVICE_COMPILE AND CUTLASS_CUDA_CLANG_FLAGS) + set(__FLAG_GROUP Clang) + set(__FLAG_LIST CUTLASS_CUDA_CLANG_FLAGS) +else(CUTLASS_NVCC_DEVICE_COMPILE AND CUTLASS_CUDA_NVCC_FLAGS) + set(__FLAG_GROUP NVCC) + set(__FLAG_LIST CUTLASS_CUDA_NVCC_FLAGS) endif() +set(__FLAG_DISPLAY_STRING "") +set(__FLAG_DISPLAY_SEPARATOR) +list(JOIN ${__FLAG_LIST} "\n " __FLAG_DISPLAY_STRING) +message(STATUS "Using the following ${__FLAG_GROUP} flags: \n ${__FLAG_DISPLAY_STRING}") + # Known gcc 8.1-8.3 SFINAE issue (fixed in gcc 8.4), check https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87748 # Also see https://github.com/NVIDIA/nccl/issues/835 for nvtx3.hpp -if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3) +if (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") endif() @@ -478,12 +494,10 @@ if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES " set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ") endif() -if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) - # CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this - # property for CMake 3.18+, so we request the NEW behavior for correct compatibility. - # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 - cmake_policy(SET CMP0104 NEW) -endif() +# CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this +# property for CMake 3.18+, so we request the NEW behavior for correct compatibility. +# https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 +cmake_policy(SET CMP0104 NEW) if (MSVC) @@ -519,55 +533,21 @@ function(cutlass_apply_cuda_gencode_flags TARGET) set(ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS_ENABLED}) endif() - set(NVCC_FLAGS) - set(CLANG_FLAGS) set(__CMAKE_CUDA_ARCHS) foreach(ARCH ${ARCHS_ENABLED}) - list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH}) set(CODES) if(CUTLASS_NVCC_EMBED_CUBIN) - list(APPEND CODES sm_${ARCH}) list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real) endif() - if(CUTLASS_NVCC_EMBED_PTX) - list(APPEND CODES compute_${ARCH}) + if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE) + # If we're using clang for device compilation, the ptx is inserted + # via another command line option and the `-virtual` flags will cause an error. list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual) endif() list(JOIN CODES "," CODES_STR) - list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) endforeach() - if (NOT __SM_ARCHS) - if (CUDA_COMPILER MATCHES "[Cc]lang") - target_compile_options( - ${TARGET} - PRIVATE - $<$:${CLANG_FLAGS}> - ) - elseif(CMAKE_VERSION GREATER_EQUAL 3.18) - set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) - else() - target_compile_options( - ${TARGET} - PRIVATE - $<$:${NVCC_FLAGS}> - ) - endif() - else() - list(JOIN CLANG_FLAGS " " CLANG_FLAGS_STR) - list(JOIN NVCC_FLAGS " " STR_NVCC_FLAGS) - if (CUDA_COMPILER MATCHES "[Cc]lang") - if(${TARGET} MATCHES ".*\.cpp") - set_source_files_properties(${TARGET} PROPERTIES COMPILE_FLAGS ${CLANG_FLAGS_STR}) - endif() - elseif(CMAKE_VERSION GREATER_EQUAL 3.18) - set_source_files_properties(${TARGET} PROPERTIES CUDA_ARCHITECTURES ${STR_NVCC_FLAGS}) - else() - if(${TARGET} MATCHES ".*\.cu") - set_source_files_properties(${TARGET} PROPERTIES COMPILE_FLAGS ${STR_NVCC_FLAGS}) - endif() - endif() - endif() + set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) endfunction() @@ -588,8 +568,8 @@ set(__CUTLASS_CUDA_NVCC_FLAGS_DEBUG ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG} CACHE INTER function(cutlass_apply_standard_compile_options TARGET) - if(CUDA_COMPILER MATCHES "[Cc]lang") - set(CUDA_COMPILE_LANGUAGE CXX) + if(CUTLASS_CLANG_DEVICE_COMPILE) + set(CUDA_COMPILE_LANGUAGE CUDA) set(_FLAGS ${__CUTLASS_CUDA_FLAGS} ${__CUTLASS_CUDA_CLANG_FLAGS}) set(_FLAGS_RELEASE ${__CUTLASS_CUDA_FLAGS_RELEASE} ${__CUTLASS_CUDA_CLANG_FLAGS_RELEASE}) set(_FLAGS_RELWITHDEBINFO ${__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${__CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO}) @@ -682,8 +662,6 @@ target_include_directories( $ $ $ - $ - $ ) # Mark CTK headers as system to supress warnings from them @@ -825,7 +803,7 @@ function(cutlass_add_executable_tests NAME TARGET) # TEST_SETS_SUPPORTED: A list of test set names these tests support. # - set(options DISABLE_EXECUTABLE_INSTALL_RULE) + set(options DISABLE_EXECUTABLE_INSTALL_RULE DO_NOT_LOWERCASE_TEST_NAME) set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX) set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -915,11 +893,15 @@ function(cutlass_add_executable_tests NAME TARGET) foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS) if (CMD_COUNT GREATER 1) - string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TESTCASE_NAME) + set(TESTCASE_NAME "${NAME}_${CMD_OPTIONS_VAR}") else() - string(TOLOWER "${NAME}" TESTCASE_NAME) + set(TESTCASE_NAME "${NAME}") endif() + if (NOT __DO_NOT_LOWERCASE_TEST_NAME) + string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME) + endif() + # The following rigmarole is needed to deal with spaces and possible quotes in # command line arguments. The options are passed "by reference" as the actual # variable names holding the real options. We then expand these in a way that diff --git a/CUDA.cmake b/CUDA.cmake index 755b7476f2..7e91adb88d 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -26,49 +26,46 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(CUDA_COMPILER MATCHES "[Cc]lang") - set(CUTLASS_NATIVE_CUDA_INIT ON) -elseif(CMAKE_VERSION VERSION_LESS 3.12.4) - set(CUTLASS_NATIVE_CUDA_INIT OFF) -else() - set(CUTLASS_NATIVE_CUDA_INIT ON) +if (CUDA_COMPILER MATCHES "[Cc]lang") + message(WARNING "CUDA_COMPILER flag is deprecated, set CMAKE_CUDA_COMPILER to desired compiler executable.") + set(__CLANG_DEVICE_COMPILATION_REQUESTED ON) +elseif(CUDA_COMPILER) + message(WARNING "Deprecated flag CUDA_COMPILER used with unknown argument ${CUDA_COMPILER}, ignoring.") endif() -set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow") - -if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH}) - # For backward compatibility, allow use of CUDA_PATH. - set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc) +if (__CLANG_DEVICE_COMPILATION_REQUESTED AND NOT DEFINED CMAKE_CUDA_COMPILER) + set(CMAKE_CUDA_COMPILER clang++) # We will let the system find Clang or error out endif() -if(CUTLASS_NATIVE_CUDA) +enable_language(CUDA) +find_package(CUDAToolkit REQUIRED) - enable_language(CUDA) - - if(NOT CUDA_VERSION) - set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) - endif() - if(NOT CUDA_TOOLKIT_ROOT_DIR) - get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) - endif() +if(NOT CUDA_VERSION) + # For backward compatibility with older CMake code. + set(CUDA_VERSION ${CUDAToolkit_VERSION}) + set(CUDA_VERSION_MAJOR ${CUDAToolkit_VERSION_MAJOR}) + set(CUDA_VERSION_MINOR ${CUDAToolkit_VERSION_MINOR}) +endif() +if(NOT CUDA_TOOLKIT_ROOT_DIR) + # In some scenarios, such as clang device compilation, the toolkit root may not be set, so we + # force it here to the nvcc we found via the CUDAToolkit package. + get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_NVCC_EXECUTABLE}/../.." ABSOLUTE) +endif() +if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])") + set(CUTLASS_NVCC_DEVICE_COMPILE ON CACHE BOOL "Using nvcc tools for device compilation") +elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang") + set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation") else() + message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.") +endif() - find_package(CUDA REQUIRED) - # We workaround missing variables with the native flow by also finding the CUDA toolkit the old way. - - if(NOT CMAKE_CUDA_COMPILER_VERSION) - set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION}) - endif() - +if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30") + message(FATAL_ERROR "Clang device compilation for CUTLASS requires CMake 3.30 or higher.") endif() if (CUDA_VERSION VERSION_LESS 9.2) - message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.") -endif() -if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang") - set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc) - message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + message(FATAL_ERROR "CUDA 9.2+ required, found ${CUDA_VERSION}.") endif() find_library( @@ -211,16 +208,6 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) # Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include # paths by default, so we add it explicitly here. -function(cutlass_correct_source_file_language_property) - if(CUDA_COMPILER MATCHES "[Cc]lang") - foreach(File ${ARGN}) - if(File MATCHES ".*\.cu$") - set_source_files_properties(${File} PROPERTIES LANGUAGE CXX) - endif() - endforeach() - endif() -endfunction() - if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all") set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON) else() @@ -306,18 +293,13 @@ function(cutlass_add_library NAME) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_library(${NAME} ${TARGET_SOURCE_ARGS} "") - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "") - endif() + + add_library(${NAME} ${TARGET_SOURCE_ARGS} "") cutlass_apply_standard_compile_options(${NAME}) + if (NOT __SKIP_GENCODE_FLAGS) - cutlass_apply_cuda_gencode_flags(${NAME}) + cutlass_apply_cuda_gencode_flags(${NAME}) endif() target_compile_features( @@ -359,13 +341,7 @@ function(cutlass_add_executable NAME) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_executable(${NAME} ${TARGET_SOURCE_ARGS}) - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS}) - endif() + add_executable(${NAME} ${TARGET_SOURCE_ARGS}) cutlass_apply_standard_compile_options(${NAME}) cutlass_apply_cuda_gencode_flags(${NAME}) @@ -388,7 +364,6 @@ function(cutlass_target_sources NAME) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) target_sources(${NAME} ${TARGET_SOURCE_ARGS}) endfunction() diff --git a/README.md b/README.md index efe47872c9..e61335f240 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. + # What's New in CUTLASS 3.6 CUTLASS 3.6.0 is an update to CUTLASS adding: diff --git a/examples/13_two_tensor_op_fusion/CMakeLists.txt b/examples/13_two_tensor_op_fusion/CMakeLists.txt index 0b1e2cdf87..6819a9766e 100644 --- a/examples/13_two_tensor_op_fusion/CMakeLists.txt +++ b/examples/13_two_tensor_op_fusion/CMakeLists.txt @@ -80,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE}) endforeach() - diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index 885aacebc1..55852730c2 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -59,11 +59,11 @@ // Also, we don't check the index value is legal and index array point is valid // for the sake of the performance. -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include diff --git a/examples/39_gemm_permute/layouts.h b/examples/39_gemm_permute/layouts.h index ffb27c4efa..3632ec0afb 100644 --- a/examples/39_gemm_permute/layouts.h +++ b/examples/39_gemm_permute/layouts.h @@ -33,11 +33,7 @@ computing reference permutations of 4/5D tensors when source data is column-major. */ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include "assert.h" -#endif #include "cutlass/cutlass.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" diff --git a/examples/41_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h index 90c0a69bd3..efca4f132d 100644 --- a/examples/41_fused_multi_head_attention/debug_utils.h +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -30,8 +30,8 @@ **************************************************************************************************/ #pragma once -#include -#include +#include +#include #include //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h index f8f06dfeab..e166af4de4 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -43,11 +43,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/aligned_buffer.h" #include "cutlass/array.h" @@ -57,12 +53,9 @@ #include "cutlass/layout/vector.h" #include "cutlass/numeric_types.h" #include "cutlass/tensor_coord.h" - #include "cutlass/gemm/gemm.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" - #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" #include "cutlass/numeric_types.h" diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h index 24530a0f2f..6860ee9e4c 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -43,11 +43,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/aligned_buffer.h" #include "cutlass/array.h" @@ -57,16 +53,12 @@ #include "cutlass/layout/vector.h" #include "cutlass/numeric_types.h" #include "cutlass/tensor_coord.h" - #include "cutlass/gemm/gemm.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" - #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" #include "cutlass/numeric_types.h" - #include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/epilogue/thread/scale_type.h" diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h index 22779b5901..5a2f928ad8 100644 --- a/examples/41_fused_multi_head_attention/fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -550,7 +550,7 @@ struct FMHAGrouped { auto prologueV = [&](int blockN) { typename MM1::Mma::IteratorB iterator_V( - typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], {problem_size_1_k, problem_size_1_n}, thread_id(), @@ -719,7 +719,7 @@ struct FMHAGrouped { } typename MM1::Mma::IteratorB iterator_V( - typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], {problem_size_1_k, problem_size_1_n}, thread_id(), @@ -761,15 +761,15 @@ struct FMHAGrouped { using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, output_t, output_accum_t>::type, output_accum_t, DefaultOp::kCount, typename DefaultOp::ElementAccumulator, output_accum_t, - kIsFirst, - kIsLast, + kIsFirst::value, + kIsLast::value, cutlass::Array>; using Epilogue = typename cutlass::epilogue::threadblock:: EpiloguePipelined< @@ -777,7 +777,7 @@ struct FMHAGrouped { typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, typename DefaultEpilogue:: @@ -795,7 +795,7 @@ struct FMHAGrouped { int col = blockN * MM1::Mma::Shape::kN; auto source_iter = createOutputAccumIter(col); auto dest_iter = gemm_kernel_utils::call_conditional< - kIsLast, + kIsLast::value, decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); @@ -817,8 +817,8 @@ struct FMHAGrouped { } if (kKeepOutputInRF) { - const bool kIsFirst = true; - const bool kIsLast = true; + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; using DefaultEpilogue = typename MM1::DefaultEpilogue; using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; using ElementCompute = typename DefaultOp::ElementCompute; diff --git a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h index 18e61f6fbc..a770e0b671 100644 --- a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -55,13 +55,14 @@ #define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ { \ if (BOOL_V) { \ - constexpr bool BOOL_NAME = true; \ + using BOOL_NAME = std::true_type; \ F(); \ } else { \ - constexpr bool BOOL_NAME = false; \ + using BOOL_NAME = std::false_type; \ F(); \ } \ } + #define DISPATCH_ARCHTAG(CC, func) \ { \ if (CC >= 80) { \ diff --git a/examples/41_fused_multi_head_attention/kernel_backward.h b/examples/41_fused_multi_head_attention/kernel_backward.h index e7372f13e9..6fd94a6c58 100644 --- a/examples/41_fused_multi_head_attention/kernel_backward.h +++ b/examples/41_fused_multi_head_attention/kernel_backward.h @@ -32,6 +32,7 @@ #pragma once #include +#include #include #include @@ -85,8 +86,6 @@ #include "gemm/mma_from_smem.h" #include "transform/tile_smem_loader.h" -#include - using namespace gemm_kernel_utils; namespace { @@ -1956,7 +1955,8 @@ struct AttentionBackwardKernel { // no-op epilogue operator - just casting and storing contents of // accum to global memory - typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op( + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1}); typename MatmulDOIVJ::BiasGradEpilogue epilogue( shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); epilogue(output_op, output_iter, accum, output_iter); @@ -2211,7 +2211,7 @@ struct AttentionBackwardKernel { incrIteration(p, query_start, key_start, next_query, next_key); DISPATCH_BOOL( next_key != key_start, kForceReloadK, ([&]() { - prologueQkNextIteration( + prologueQkNextIteration( shared_storage, p, next_query, next_key, warp_id, lane_id); })); } @@ -2342,7 +2342,7 @@ struct AttentionBackwardKernel { thread_id, cutlass::MatrixCoord{0, 0}); - MatmulQK::Mma::prologue( + MatmulQK::Mma::template prologue( shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), iterator_A, @@ -2369,6 +2369,7 @@ struct AttentionBackwardKernel { p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, thread_id); + accumulateInGmem( shared_storage.gradV_epilogue_final(), output_frags.gradV, @@ -2406,7 +2407,7 @@ struct AttentionBackwardKernel { int thread_id = 32 * warp_id + lane_id; DISPATCH_BOOL( first, kIsFirst, ([&]() { - static constexpr auto ScaleType = kIsFirst + static constexpr auto ScaleType = kIsFirst::value ? cutlass::epilogue::thread::ScaleType::Nothing : cutlass::epilogue::thread::ScaleType::NoBetaScaling; using EpilogueOutputOp = diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index 4c80f549f3..71d79415e9 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -38,6 +38,7 @@ #include #include +#include #include #include "cutlass/fast_math.h" @@ -71,8 +72,6 @@ #include "gemm_kernel_utils.h" #include "transform/tile_smem_loader.h" -#include - using namespace gemm_kernel_utils; namespace { @@ -1036,15 +1035,15 @@ struct AttentionKernel { using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, output_t, output_accum_t>::type, output_accum_t, DefaultOp::kCount, typename DefaultOp::ElementAccumulator, ElementCompute, - kIsFirst, - kIsLast, + kIsFirst::value, + kIsLast::value, cutlass::Array>; using Epilogue = typename cutlass::epilogue::threadblock:: EpiloguePipelined< @@ -1052,7 +1051,7 @@ struct AttentionKernel { typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, typename DefaultEpilogue:: @@ -1070,7 +1069,7 @@ struct AttentionKernel { int col = blockN * MM1::Mma::Shape::kN; auto source_iter = createOutputAccumIter(col); auto dest_iter = call_conditional< - kIsLast, + kIsLast::value, decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h index afa20871e1..1acb4a2de6 100644 --- a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -39,11 +39,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -53,12 +49,9 @@ #include "cutlass/tensor_coord.h" #include "cutlass/aligned_buffer.h" #include "cutlass/functional.h" - #include "cutlass/gemm/gemm.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" - #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py index a1f2998c4f..6474d95c5d 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py @@ -43,7 +43,7 @@ def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir def gen_cpp_sample(self): code = "/* Auto Generated code - Do not edit.*/\n" - code += "#include \n" + code += "#include \n" code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n" code += "#include \"cutlass/cutlass.h\" \n" diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py index 117179ed2a..db1ec4c72f 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py @@ -380,7 +380,7 @@ def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir def gen_CUTLASS_irrelevant_API(self): code = "" code += "#include \n" - code += "#include \n" + code += "#include \n" param_name = "Fused" + str(self.b2b_num) + "xGemm_" for i in range(self.b2b_num): diff --git a/examples/45_dual_gemm/test_run.h b/examples/45_dual_gemm/test_run.h index 2bd6c720a4..4a58a3a16c 100644 --- a/examples/45_dual_gemm/test_run.h +++ b/examples/45_dual_gemm/test_run.h @@ -66,7 +66,7 @@ int testRun(int arch, std::vector & test_funcs, const std::string & return -1; } - if (!(props.major == arch_major && props.minor == arch_minor)) { + if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) { supported = false; } diff --git a/examples/45_dual_gemm/threadblock/dual_epilogue.h b/examples/45_dual_gemm/threadblock/dual_epilogue.h index cd2288af6f..3ef1c6d33c 100644 --- a/examples/45_dual_gemm/threadblock/dual_epilogue.h +++ b/examples/45_dual_gemm/threadblock/dual_epilogue.h @@ -38,11 +38,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index 884f3535d0..0a74e02a83 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -45,18 +45,18 @@ and BEFORE scatter operations are applied. */ -#include -#include -#include -#include -#include -#include - +#include +#include +#include +#include +#include #include #include #include #include +#include + #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -64,7 +64,6 @@ #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" - #include "cutlass/util/command_line.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu index c1d3caeba6..ab82b40cca 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -619,7 +619,6 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // @@ -681,4 +680,4 @@ int main(int argc, char const **args) { return 0; } -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index 5fc96d4eed..40fa689489 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -559,4 +559,4 @@ int main(int argc, char const **args) { return 0; } -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 8d22c75f84..48eca35c2d 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -36,4 +36,4 @@ We are currently optimizing the following cases: * Optimizations for memory bound cases. -* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. \ No newline at end of file +* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp index 7f7011265b..02e257c3fe 100644 --- a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -207,4 +207,4 @@ bool initialize_packed_scale( return false; } return true; -} \ No newline at end of file +} diff --git a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp index 97df3b8784..de5a3d3fd0 100644 --- a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp @@ -159,4 +159,4 @@ void reorder_tensor( cutlass::DeviceAllocation temp(size(layout_src)); reorder_tensor(data, layout_src, temp.get(), layout_dst); cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); -} \ No newline at end of file +} diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index d57e1deea5..7b20a33548 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -63,7 +63,7 @@ #include #include #include -#include +#include #include "cutlass/cutlass.h" diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp index 57365a8b36..bfb64820f0 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp @@ -35,9 +35,36 @@ #include "dispatch_policy_extra.hpp" #include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp" +#include "../pipeline/prefetch_pipeline_sm90.hpp" namespace cutlass::gemm::collective { +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size + constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); + constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes; +} + +} // namespace detail + // GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch template < class ElementA, @@ -98,7 +125,7 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; @@ -184,7 +211,7 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp index 710224d78c..9bcb1f5a7e 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -57,6 +57,19 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +constexpr int PrefetchStages = 4; +constexpr int PrefetchInitialStages = 1; +// This determines how much shmem we set aside for prefetch. +// We don't reuse anything loaded by prefetcher, so we can keep +// loading into the same place -- there will be a conflict when +// writing, but it doesn't affect performance as much as the doors +// that this opens. +constexpr int PrefetchStagesActual = 1; + +} // namespace detail + // WarpSpecialized Mainloop template < int Stages, @@ -117,15 +130,7 @@ struct CollectiveMma< static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - static constexpr int PrefetchStages = 4; - static constexpr int PrefetchInitialStages = 1; - // This determines how much shmem we set aside for prefetch. - // We don't reuse anything loaded by prefetcher, so we can keep - // loading into the same place -- there will be a conflict when - // writing, but it doesn't affect performance as much as the doors - // that this opens. - static constexpr int PrefetchStagesActual = 1; - using PrefetcherPipeline = cutlass::PrefetchPipeline; + using PrefetcherPipeline = cutlass::PrefetchPipeline; using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; @@ -155,7 +160,7 @@ struct CollectiveMma< using PrefetchSmemLayoutA = decltype(make_layout(make_shape( cute::Int(SmemLayoutA{})>{}, cute::Int(SmemLayoutA{})>{}, - cute::Int{}))); + cute::Int{}))); static constexpr auto prefetch_smem_size = cute::cosize_v; @@ -176,7 +181,7 @@ struct CollectiveMma< using InternalElementB = cute::conditional_t>>; // Defined outside the class where it's used, to work around MSVC issues - using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; struct SharedStorage { struct TensorStorage : cute::aligned_struct<128, _0> { @@ -660,7 +665,7 @@ struct CollectiveMma< bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); - prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages); + prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages); Tensor sA = make_tensor( make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) @@ -702,7 +707,7 @@ struct CollectiveMma< break; } - prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages); + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages); using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); diff --git a/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt b/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt new file mode 100644 index 0000000000..183202593c --- /dev/null +++ b/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 64_ada_fp8_gemm_grouped + ada_fp8_gemm_grouped.cu + ) diff --git a/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu new file mode 100644 index 0000000000..8e3dbbb08b --- /dev/null +++ b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu @@ -0,0 +1,1208 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Ada FP8 GEMM Grouped With Per-Group Scale Example. + + This workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices + in Global Memory are passed to the kernel in array (also held in Global Memory). Similarly, + leading dimensions and problem sizes are stored in arrays in GMEM. + + This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM + concept may be distinct. + + The differences between this and the examples/24_gemm_grouped are: (1) this example scales the output of each GEMM by a different scalar value specified by alpha_ptr_array. (2) this example uses FP8 tensorcore. + + This benchmark program initializes a workspace with random problem sizes for a given number of + groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to + model problems more similar to the traditional batched GEMM. + + Additionally, problem sizes are collected and binned to compute the same problem as a series of + conventional batched GEMMs (setup for this problem is not timed). This demonstrates the performance + enhancement achieved by implementing a specialized grouped GEMM kernel. + + Examples: + + # Runs a grouped GEMM with 100 random problem sizes + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 + + # Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024) + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --k=1024 --verbose=true + + # Runs a grouped GEMM that is equivalent to a batched GEMM + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true + + # Execute Grouped GEMM and profile with NSight + $ nv-nsight-cu-cli ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --m=256 --n=256 --k=256 --verbose=true \ + --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double initialization_time_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double initialization_time_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for cutlass::gemm::GemmCoord +struct HashGemmCoord { + size_t operator()(cutlass::gemm::GemmCoord const &problem) const { + std::hash hasher; + return (hasher(problem.m() * 3)) ^ (hasher(1 + problem.n() * 5)) ^ (hasher(2 + problem.k() * 7)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool profile_initialization; + bool sort_problems; + + std::vector problem_sizes; + + // problem size bins + std::unordered_map< + cutlass::gemm::GemmCoord, + std::vector, + HashGemmCoord> problem_bins; + + int alignment; + int problem_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + std::vector alpha_array; + float beta; + std::string benchmark_path; + + std::string output_tag; + std::ofstream output_file; + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector scheduler_modes; + + std::unordered_map + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast(m); + } + }; + + std::unordered_map + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; + + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(16), + reference_check(true), + profile_initialization(false), + sort_problems(false), + problem_count(15), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 16); + cmd.get_cmd_line_argument("groups", problem_count, 15); + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + + std::string output_path; + cmd.get_cmd_line_argument("tag", output_tag); + cmd.get_cmd_line_argument("output_file", output_path); + + if (!output_path.empty()) { + + std::ios_base::openmode open_mode = std::ios_base::out; + + std::ifstream input_file(output_path.c_str()); + + if (input_file.good()) { + open_mode = std::ios_base::app; + input_file.close(); + } + + output_file.open(output_path.c_str(), open_mode); + + if (output_file.good() && open_mode != std::ios_base::app) { + output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; + } + } + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + error = true; + problem_sizes.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + // Post-process the problem sizes + bin_problems(); + + // Initalize alpha array + randomize_alpha_ptr_array(cmd); + } + + void randomize_problems(cutlass::CommandLine &cmd) { + + // + // For now, randomly choose the problem sizes. + // + + int cmd_line_m = -1; + int cmd_line_n = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes.reserve(problem_count); + + for (int i = 0; i < problem_count; ++i) { + + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + + if (m < 1) { + m = alignment * ((rand() % 256) + 1); + } + + if (n < 1) { + n = alignment * ((rand() % 256) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 256) + 1); + } + + cutlass::gemm::GemmCoord problem(m, n, k); + + problem_sizes.push_back(problem); + } + } + + void randomize_alpha_ptr_array(cutlass::CommandLine &cmd) { + alpha_array.resize(problem_count); + for (int i = 0; i < problem_count; ++i) { + alpha_array[i] = static_cast((rand() % 100) - 50 + alpha); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes.push_back(extent); + } + } + + return true; + } + + /// Post processes the problems + void bin_problems() { + + problem_bins.clear(); + + problem_count = int(problem_sizes.size()); + + // + // Insert the problem sizes into a sorted container class. This is *NOT* necessary + // to run the CUTLASS kernel, but it enables the execution of cublas's batched GEMM. + // + for (int i = 0; i < int(problem_sizes.size()); ++i) { + auto it = problem_bins.find(problem_sizes.at(i)); + if (it == problem_bins.end()) { + problem_bins.insert({problem_sizes.at(i), std::vector({i}) }); + } + else { + it->second.push_back(i); + } + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "64_ada_fp8_gemm_grouped\n\n" + << " This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM\n" + << " in that multiple, independent GEMMs are computed by one grid launch. It differs in that each\n" + << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" + << " in device Global Memory and loaded by the kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark= Executes a benchmark problem size.\n" + << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag= String tag to prepend to the CSV file.\n" + << " --groups= Number of individual GEMM problems (default: --groups=15)\n" + << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" + << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n" + << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n" + << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a grouped GEMM with 100 random problem sizes\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100\n\n" + + << "# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" + + << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" + << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" + << "#\n" + << "# For example, assume the following are the contents of 'problems.txt'\n" + << "#\n" + << "# 0 1024x256x520\n" + << "# 1 520x264x1024\n" + << "# 2 96x48x1024\n" + << "#\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --benchmark=problems.txt\n\n" + + << "# Execute Grouped GEMM and profile with NSight\n" + << "$ nv-nsight-cu-cli ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --m=256 --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + for (auto const & problem : problem_sizes) { + fmas += problem.product(); + } + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BaseTestbed { +public: + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + std::vector alpha_ptr_array_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation alpha_array_device; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation alpha_ptr_array_device; + + BaseTestbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + int problem_count() const { + return options.problem_count; + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = static_cast(5); + scope_min = static_cast(-5); + } + else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Allocates device-side data + void allocate() { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + lda_host.resize(problem_count()); + ldb_host.resize(problem_count()); + ldc_host.resize(problem_count()); + ldd_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem = options.problem_sizes.at(i); + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + lda.reset(problem_count()); + ldb.reset(problem_count()); + ldc.reset(problem_count()); + ldd.reset(problem_count()); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + alpha_ptr_array_host.resize(problem_count()); + alpha_array_device.reset(problem_count()); + alpha_ptr_array_device.reset(problem_count()); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(problem_count()); + std::vector ptr_B_host(problem_count()); + std::vector ptr_C_host(problem_count()); + std::vector ptr_D_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count()); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count()); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count()); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count()); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), block_D.size(), ElementC(), ElementC()); + + // Initialize alpha array + alpha_array_device.copy_from_host(options.alpha_array.data()); + for (int32_t i = 0; i < problem_count(); ++i) { + alpha_ptr_array_host.at(i) = alpha_array_device.get() + i; + } + alpha_ptr_array_device.copy_from_host(alpha_ptr_array_host.data()); + } + + /// Verifies the result is a GEMM + bool verify() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); + + cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + options.alpha_array[i], + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + options.beta, + view_C, + view_Ref_device, + ElementAccumulator(0) + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); + + cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + return passed; + } + } + + return passed; + } + +}; + +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine GEMM with different GroupScheduleMode_ + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGroupedPerGroupScale< + typename Gemm_::ElementA, + typename Gemm_::LayoutA, + Gemm_::kTransformA, + Gemm_::kAlignmentA, + typename Gemm_::ElementB, + typename Gemm_::LayoutB, + Gemm_::kTransformB, + Gemm_::kAlignmentB, + typename Gemm_::ElementC, + typename Gemm_::LayoutC, + typename Gemm_::ElementAccumulator, + typename Gemm_::OperatorClass, + typename Gemm_::ArchTag, + typename Gemm_::ThreadblockShape, + typename Gemm_::WarpShape, + typename Gemm_::InstructionShape, + typename Gemm_::EpilogueOutputOp, + typename Gemm_::ThreadblockSwizzle, + Gemm_::kStages, + GroupScheduleMode_>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + std::cout << std::endl; + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Gemm::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Gemm::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime + Result profile() { + std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + + std::cout << std::endl; + std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the GEMM arguments + typename Gemm::EpilogueOutputOp::ElementCompute ** alpha_ptr_array = this->alpha_ptr_array_device.get(); + typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha_ptr_array, nullptr); + + // Configure GEMM arguments + typename Gemm::Arguments args( + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = gemm.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Run the grouped GEMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped GEMM object + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + gemm.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + gemm.initialize(args, workspace.get()); + } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); + } + + int64_t total_tiles = Gemm::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; + + std::cout << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + std::cout << "\nPassed\n"; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { + std::cerr << "This example requires CUDA 12.4 or greater." << std::endl; + return 0; + } + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() failed with error: " << cudaGetErrorString(result) << std::endl; + return 0; + } + + if (!(properties.major == 8 && properties.minor == 9)) { + std::cerr << "CUTLASS's Ada FP8 Gemm Grouped example requires a device of compute capability 89.\n" << std::endl; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Define the Grouped and Batched GEMM types + // + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + constexpr int ElementsPerAccessB = 128 / cutlass::sizeof_bits::value; + + // Define a grouped GEMM kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + ElementsPerAccessA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + ElementsPerAccessB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 4>::GemmKernel; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + // + // Profile it + // + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } + + if (result.error != cudaSuccess) { + return 1; + } + + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6486d71435..7e8d45227b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -143,8 +143,10 @@ foreach(EXAMPLE 61_hopper_gemm_with_topk_and_softmax 62_hopper_sparse_gemm 63_hopper_gemm_with_weight_prefetch + 64_ada_fp8_gemm_grouped ) add_subdirectory(${EXAMPLE}) endforeach() + diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu index 87ad873ce6..a8ae3b1040 100644 --- a/examples/cute/tutorial/tiled_copy.cu +++ b/examples/cute/tutorial/tiled_copy.cu @@ -95,36 +95,17 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout) /// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation /// has the precondition that pointers are aligned to the vector size. /// -template -__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout) +template +__global__ void copy_kernel_vectorized(TensorS S, TensorD D, Tiled_Copy tiled_copy) { using namespace cute; - using Element = typename TensorS::value_type; // Slice the tensors to obtain a view into each tile. Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) - // Define `AccessType` which controls the size of the actual memory access. - using AccessType = cutlass::AlignedArray; - - // A copy atom corresponds to one hardware memory access. - using Atom = Copy_Atom, Element>; - - // Construct tiled copy, a tiling of copy atoms. - // - // Note, this assumes the vector and thread layouts are aligned with contigous data - // in GMEM. Alternative thread layouts are possible but may result in uncoalesced - // reads. Alternative vector layouts are also possible, though incompatible layouts - // will result in compile time errors. - auto tiled_copy = - make_tiled_copy( - Atom{}, // access size - ThreadLayout{}, // thread layout - VecLayout{}); // vector layout (e.g. 4x1) - // Construct a Tensor corresponding to each thread's slice. - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x); Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN) Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN) @@ -198,11 +179,34 @@ int main(int argc, char** argv) Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n') Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n') + // Construct a TiledCopy with a specific access pattern. + // This version uses a + // (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc), + // (2) Layout-of-Values that each thread will access. + // Thread arrangement - Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); + Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); // (32,8) -> thr_idx - // Vector dimensions - Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); + // Value arrangement per thread + Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx + + // Define `AccessType` which controls the size of the actual memory access instruction. + using CopyOp = UniversalCopy>; // A very specific access width copy instruction + //using CopyOp = UniversalCopy>; // A more generic type that supports many copy strategies + //using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs + + // A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element. + using Atom = Copy_Atom; + + // Construct tiled copy, a tiling of copy atoms. + // + // Note, this assumes the vector and thread layouts are aligned with contigous data + // in GMEM. Alternative thread layouts are possible but may result in uncoalesced + // reads. Alternative value layouts are also possible, though incompatible layouts + // will result in compile time errors. + TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy + thr_layout, // thread layout (e.g. 32x4 Col-Major) + val_layout); // value layout (e.g. 4x1) // // Determine grid and block dimensions @@ -217,8 +221,7 @@ int main(int argc, char** argv) copy_kernel_vectorized<<< gridDim, blockDim >>>( tiled_tensor_S, tiled_tensor_D, - thr_layout, - vec_layout); + tiled_copy); cudaError result = cudaDeviceSynchronize(); if (result != cudaSuccess) { diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index 9d080116da..c9e02245e2 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -51,19 +51,14 @@ naive_cooperative_copy(uint32_t const& tid, Tensor const& src, Tensor & dst) { - auto N = size(src); - if (tid < N) { - uint32_t upper_bound = (N / NumThreads) * NumThreads; - CUTE_UNROLL - for (uint32_t i = 0; i < upper_bound; i += NumThreads) { // All in-bounds - dst[tid + i] = src[tid + i]; - } - if (N % NumThreads != 0) { // Likely static condition - uint32_t final_idx = tid + upper_bound; - if (final_idx < N) { // Final in-bounds - dst[final_idx] = src[final_idx]; - } - } + auto N = size(dst); + auto R = N % Int{}; + if (R > 0 && tid < R) { // Likely static condition && Residue in-bounds + dst[tid] = src[tid]; + } + CUTE_UNROLL + for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) { // All in-bounds + dst[tid + i] = src[tid + i]; } } @@ -117,12 +112,14 @@ heuristic_permutation(Tensor const& a, // template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor & dst) + Tensor & dst, + CopyPolicy const& cpy = {}) { // Assumes the shapes are static, can generalize/fallback CUTE_STATIC_ASSERT_V(is_static{} && is_static{}); @@ -283,23 +280,28 @@ cooperative_copy(uint32_t const& tid, // If we're using all threads (static) or the tid is in-range (dynamic) if (vec_thrs == NumThreads or tid < vec_thrs) { - return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + auto src_c = recast(src_v); + auto dst_c = recast(dst_v); + return copy(cpy, src_c, dst_c); } } } + // Default max-vectorization size to value_type size template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor & dst) + Tensor & dst, + CopyPolicy const& cpy = {}) { constexpr uint32_t MaxVecBits = sizeof_bits_v; - return cooperative_copy(tid, src, dst); + return cooperative_copy(tid, src, dst, cpy); } // @@ -308,26 +310,30 @@ cooperative_copy(uint32_t const& tid, template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor && dst) + Tensor && dst, + CopyPolicy const& cpy = {}) { - return cooperative_copy(tid, src, dst); + return cooperative_copy(tid, src, dst, cpy); } template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor && dst) + Tensor && dst, + CopyPolicy const& cpy = {}) { - return cooperative_copy(tid, src, dst); + return cooperative_copy(tid, src, dst, cpy); } } // end namespace cute diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index 2c91ce6f45..e4bd5ea628 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -50,31 +50,115 @@ namespace cute namespace detail { +// Slow fallback path: +template +CUTE_HOST_DEVICE +void +epilogue_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename ThrMMA::ValTypeC; + CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v); + + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + // Custom axpby_if for now + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) + { + if (elem_less(tCcC(i), shape(sC))) + { + tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i) + : alpha * tCrC(i) + + beta * static_cast(sC_load_op(tCsC(i)))); + } + } +} + +template +CUTE_HOST_DEVICE +void +epilogue_no_predication(Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C + SmemCopyOpC const& sC_copy_op) +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename TRC::value_type; + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + Tensor tCrDi = make_fragment_like(tCsC); + Tensor tCrD = make_fragment_like(tCrC); + if(!isBetaZero) { + copy(sC_copy_op, tCsC, tCrDi); + // Transform C on/after load + cute::transform(tCrDi, tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, tCrDi, sC_store_op); + copy(sC_copy_op, tCrDi, tCsC); +} + // Predicated Cooperative GEMM template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class TA, class ALayout, class TB, class BLayout, + class TC, class RCLayout, + class ALoadTransformOp, class BLoadTransformOp> CUTE_HOST_DEVICE void -cooperative_gemm_predication(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +cooperative_gemm_predication(ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op) // transforms B values before use in GEMM { - using TypeA = typename TA::value_type; - using TypeB = typename TB::value_type; - using TypeC = typename TC::value_type; + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; // // MMA Partitioning @@ -83,22 +167,18 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, // Partition the sA, sB, and sC tiles across the threads for the MMA Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) - Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) - Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) #if 0 if (thread0()) { print(" sA: "); print( sA); print("\n"); print(" sB: "); print( sB); print("\n"); - print(" sC: "); print( sC); print("\n"); print(thr_mma); print("tCsA: "); print(tCsA); print("\n"); print("tCsB: "); print(tCsB); print("\n"); - print("tCsC: "); print(tCsC); print("\n"); print("tCrA: "); print(tCrA); print("\n"); print("tCrB: "); print(tCrB); print("\n"); print("tCrC: "); print(tCrC); print("\n"); @@ -154,23 +234,20 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M CUTE_UNROLL for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I - tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{}; } } CUTE_UNROLL for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N CUTE_UNROLL for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I - tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{}; } } // // MAINLOOP // - // Clear accumulators - clear(tCrC); - CUTE_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { @@ -185,138 +262,80 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M CUTE_UNROLL for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I - tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{}; } } CUTE_UNROLL for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N CUTE_UNROLL for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I - tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{}; } } } // GEMM on k_block in registers gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } - - // - // Epilogue - // - - // Create coordinate tensors for the problem - Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) - // Repeat partitioning with thr_mma - Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) - - const bool isBetaZero = (beta == Beta{}); - - // Custom axpby_if for now - CUTE_UNROLL - for (int i = 0; i < size(tCrC); ++i) - { - if (elem_less(tCcC(i), shape(sC))) - { - tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast(tCrC(i)) - : alpha * static_cast(tCrC(i)) + - beta * static_cast(sC_load_op(tCsC(i)))); - } - } -} - -// Slow fallback path -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> -CUTE_HOST_DEVICE -void -cooperative_gemm_predication(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C -{ - // ThrMMA - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op); } // Unpredicated Cooperative GEMM -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class SmemCopyOpA, class SmemCopyOpB> CUTE_HOST_DEVICE void -cooperative_gemm_no_predication(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +cooperative_gemm_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op, + SmemCopyOpB const& sB_copy_op) { - using TypeA = typename TA::value_type; - using TypeB = typename TB::value_type; - using TypeC = typename TC::value_type; + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; - // ThrMMA - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); // // MMA Partitioning // - Tensor tCsC = thr_mma.partition_C(sC); // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) - Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) using CopyOpAType = SmemCopyOpA; using CopyOpBType = SmemCopyOpB; - auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); Tensor tCsA = smem_thr_copy_A.partition_S(sA); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + Tensor tCrAi = make_fragment_like(tCsA); + Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K - auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); Tensor tCsB = smem_thr_copy_B.partition_S(sB); - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + Tensor tCrBi = make_fragment_like(tCsB); + Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K #if 0 if (thread0()) { print(" sA: "); print(sA); print("\n"); print(" sB: "); print(sB); print("\n"); - print(" sC: "); print(sC); print("\n"); print(thr_mma); print("\n"); - print("tCsC: "); print(tCsC); print("\n"); print("tCrA: "); print(tCrA); print("\n"); print("tCrB: "); print(tCrB); print("\n"); print("tCrC: "); print(tCrC); print("\n"); @@ -333,15 +352,12 @@ cooperative_gemm_no_predication(uint32_t thread_idx, // PREFETCH // - copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); - copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{})); // // MAINLOOP // - // Clear accumulators - clear(tCrC); - constexpr int K_BLOCK_MAX = size<2>(tCrA); CUTE_UNROLL @@ -352,132 +368,178 @@ cooperative_gemm_no_predication(uint32_t thread_idx, { // Load the next k_block int k_next = k_block + 1; // statically unrolled - copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next)); - copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next)); + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next)); } // Transform A and B, relying on the compiler to remove in case of identity ops - cute::transform(tCrA(_,_,k_block), sA_load_op); - cute::transform(tCrB(_,_,k_block), sB_load_op); + cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op); // GEMM on k_block in registers gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } - - // - // Epilogue - // - - auto isBetaZero = [&] () { - if constexpr (is_complex::value) { - return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; - } - else { - return beta == Int<0>{}; - } - CUTE_GCC_UNREACHABLE; - } (); - - using CopyOpCType = SmemCopyOpC; - Tensor tCrD = thr_mma.make_fragment_C(tCsC); - if(!isBetaZero) { - copy(CopyOpCType{}, tCsC, tCrD); - // Transform C on/after load - cute::transform(tCrD, sC_load_op); - } - // C = alpha * (A * B) + beta * C - axpby(alpha, tCrC, beta, tCrD); - // Transform C before/on store - cute::transform(tCrD, sC_store_op); - copy(CopyOpCType{}, tCrD, tCsC); } } // end namespace detail -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy, + class SmemCopyOpC = DefaultCopy> CUTE_HOST_DEVICE void -cooperative_gemm(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyOpC const& sC_copy_op = {}) { + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK - using TypeA = typename TA::value_type; - using TypeB = typename TB::value_type; - using TypeC = typename TC::value_type; - - static_assert(is_convertible_v>, TypeA>, - "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type"); - static_assert(is_convertible_v>, TypeB>, - "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type"); - static_assert(is_convertible_v>, TypeC>, - "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); - static_assert(is_convertible_v>, TypeC>, - "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); - - static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), - tile_shape(TiledMMA{})); - if constexpr (compat) { - detail::cooperative_gemm_no_predication( - thread_idx, tiled_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) :: InputTypeC + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) :: ComputeTypeC + + // Clear accumulators + clear(tCrC); + +#if 0 + if (thread0()) { + print(" sC: "); print(sC); print("\n"); + print(" tCsC: "); print(tCsC); print("\n"); + } +#endif + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + detail::epilogue_no_predication( + alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); + } +} + +// C already partitioned into registers on input +// It can be passed non-empty +// Epilogue not included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + // Check if input C fragment is compatible with thr_mma and problem size + using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB)))); + CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC))); + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op ); } else { detail::cooperative_gemm_predication( - thread_idx, tiled_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op ); } } +// Accept mutable temporaries template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy, + class SmemCopyOpC = DefaultCopy> CUTE_HOST_DEVICE void cooperative_gemm(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor && sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyOpC const& sC_copy_op = {}) { - using CopyOpA = AutoVectorizingCopyWithAssumedAlignment>; - using CopyOpB = AutoVectorizingCopyWithAssumedAlignment>; - using CopyOpC = AutoVectorizingCopyWithAssumedAlignment>; - cooperative_gemm( - thread_idx, tiled_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op - ); + cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op, + sA_copy_op, sB_copy_op, sC_copy_op); } // Legacy overload of cute::gemm for backwards-compatibility @@ -485,27 +547,38 @@ template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity> CUTE_HOST_DEVICE void -gemm(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + // Goes directly to the slow path to avoid getting thread_idx from thr_mma detail::cooperative_gemm_predication( - thr_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op + thr_mma, sA, sB, sC, sA_load_op, sB_load_op + ); + + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op ); } diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index c2decd15d7..84ef49161d 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -38,79 +38,6 @@ namespace cute { -// -// Accept mutable temporaries -// - -template -CUTE_HOST_DEVICE -void -copy(Tensor const& src, - Tensor && dst) -{ - return copy(src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_vec(Tensor const& src, - Tensor && dst) -{ - return copy_vec(src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_aligned(Tensor const& src, - Tensor && dst) -{ - return copy_aligned(src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_if(PrdTensor const& pred, - Tensor const& src, - Tensor && dst) -{ - return copy_if(pred, src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_if(CopyPolicy const& copy_policy, - PrdTensor const& pred, - Tensor const& src, - Tensor && dst) -{ - return copy_if(copy_policy, pred, src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy(CopyPolicy const& copy_policy, - Tensor const& src, - Tensor && dst) -{ - return copy(copy_policy, src, dst); -} - // // copy_if -- Predicated Copy // @@ -124,12 +51,13 @@ copy_if(PrdTensor const& pred, Tensor const& src, Tensor & dst) { - auto copy_op = select_elementwise_copy(src, dst); + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; CUTE_UNROLL - for (int i = 0; i < size(src); ++i) { + for (int i = 0; i < size(dst); ++i) { if (pred(i)) { - copy_op.copy(src(i), dst(i)); + dst(i) = static_cast(static_cast(src(i))); } } } @@ -138,17 +66,6 @@ copy_if(PrdTensor const& pred, // copy_if -- Predicated CopyAtom // -namespace detail { - -// Trait that detects if atom's traits has a member function with(bool) -template -constexpr bool has_with_bool = false; - -template -constexpr bool has_with_bool().with(declval()))>> = true; - -} // end namespace detail - template const& copy_atom, Tensor & dst) // (V,Rest...) { static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy - copy_atom.call(src, dst); + if constexpr (has_with_bool) { + copy_atom.with(pred()).call(src, dst); + } else { + if (pred()) { copy_atom.call(src, dst); } + } } else { // Loop over all but the first mode constexpr int R = SrcLayout::rank; Tensor src_v = group_modes<1,R>(src); Tensor dst_v = group_modes<1,R>(dst); CUTE_UNROLL - for (int i = 0; i < size<1>(src_v); ++i) { - // If copy traits can be transformed with a predicate value, do it, otherwise branch here - if constexpr (detail::has_with_bool>) { + for (int i = 0; i < size<1>(dst_v); ++i) { + if constexpr (has_with_bool) { copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i)); } else { - if (pred(i)) { - copy_atom.call(src_v(_,i), dst_v(_,i)); - } + if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); } } } } } // -// copy_vec -- attempt vectorized copy with VecType +// copy_if -- AutoCopyAsync // - -template CUTE_HOST_DEVICE void -copy_vec(Tensor const& src, - Tensor & dst) +copy_if(AutoCopyAsync const& cpy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) { - static_assert(sizeof_bits_v >= 8 && sizeof_bits_v % 8 == 0, - "Expected a vectorization type of at least a byte."); + using SrcElemWithConst = remove_reference_t; using SrcType = typename SrcEngine::value_type; using DstType = typename DstEngine::value_type; - if constexpr (cute::is_same::value && - sizeof_bits_v > sizeof_bits_v) - { - // Preserve volatility of Src/Dst types. - using SrcVecType = conditional_t, VecType const volatile, VecType const>; - using DstVecType = conditional_t, VecType volatile, VecType >; - Tensor src_v = recast(src); - Tensor dst_v = recast(dst); -#if 0 - if (thread0()) { - print("copy_vec<%db> -- vectorizing copy:\n", int(sizeof_bits_v)); - print(" "); print(src); print(" => "); print(src_v); print("\n"); - print(" "); print(dst); print(" => "); print(dst_v); print("\n"); + auto copy_op = []() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType)) { + if constexpr (is_const_v && sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEGLOBAL{}; + } else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + } else { + return UniversalCopy{}; } -#endif - return copy_if(TrivialPredTensor{}, src_v, dst_v); - } else { -#if 0 - if (thread0()) { - print("copy_vec<%db> -- NOT vectorizing copy:\n", int(sizeof_bits_v)); - print(" "); print(src); print("\n"); - print(" "); print(dst); print("\n"); - } + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; #endif + }(); - return copy_if(TrivialPredTensor{}, src, dst); + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } } } +// +// copy -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy(AutoCopyAsync const& cpy, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + copy_if(cpy, TrivialPredTensor{}, src, dst); +} + // // copy -- CopyAtom // @@ -238,15 +172,56 @@ template const& copy_atom, - Tensor const& src, - Tensor & dst) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) { - return copy_if(copy_atom, TrivialPredTensor{}, src, dst); + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + + if constexpr (is_static::value && is_static::value) { + CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v)); + + // AutoFilter on the Rest-mode + auto dst_null = nullspace(layout<1>(dst_v)); + + Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest)) + Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest)) + + CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n)); + CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error"); + CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy"); + CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{})); + CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{})); + + Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + + CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c)); + CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst)); + CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src)); + + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_c); ++i) { + copy_atom.call(src_c(_,i), dst_c(_,i)); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } } -////////////////////////////////////////// -// Special Auto-Vectorizing Overloads -////////////////////////////////////////// +//////////////////////////////////////////////////////// +// Special Auto-Vectorizing, Auto-Filtering Overloads // +//////////////////////////////////////////////////////// // Specialization for AutoVectorizingCopyAssumedAlignment template const&, Tensor const& src, Tensor & dst) { - constexpr int vec_elem = decltype(max_common_vector(src, dst))::value; - - constexpr int max_align_src = decltype(max_alignment(src.layout()))::value; - constexpr int max_align_dst = decltype(max_alignment(dst.layout()))::value; - constexpr int max_align = gcd(vec_elem, max_align_src, max_align_dst); + constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst)); + constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); + static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit copy!"); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + + if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) { + // If more than one element vectorizes to 8bits or more, then recast and copy + using VecType = uint_bit_t; + // Preserve volatility + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; - constexpr int src_bits = sizeof_bits::value; - constexpr int vec_bits = gcd(src_bits * max_align, MaxVecBits); + // Recast + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); - if constexpr (vec_elem > 1 && vec_bits >= 8) { - // If more than one element vectorizes to 8bits or more, then copy_vec #if 0 if (thread0()) { - print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits); - print(" "); print(src); print("\n"); - print(" "); print(dst); print("\n"); + print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits); + print(" "); print(src); print(" => "); print(src_v); print("\n"); + print(" "); print(dst); print(" => "); print(dst_v); print("\n"); } #endif - return copy_vec>(src, dst); + + return copy_if(TrivialPredTensor{}, src_v, dst_v); } else { return copy_if(TrivialPredTensor{}, src, dst); } } +template +struct AutoFilter { + Base const& base; + CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {} +}; + +// Specialization for AutoFilter +template +CUTE_HOST_DEVICE +void +copy(AutoFilter const& copy_op, + Tensor const& src, + Tensor & dst) +{ + if constexpr (is_constant::value) { + auto dst_null = nullspace(dst.layout()); + + Tensor dst_n = zipped_divide(dst, dst_null); + Tensor src_n = zipped_divide(src, dst_null); + + CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error"); + CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy"); + + copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_)); + } else { + copy(copy_op.base, src, dst); + } +} + // Auto-vectorizing copy for static layouts template @@ -292,7 +304,11 @@ copy(Tensor const& src, { if constexpr (is_static::value && is_static::value) { // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned - return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned. + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst); } else { // Do not assume that dynamic layouts are aligned. return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst); @@ -307,7 +323,12 @@ void copy_aligned(Tensor const& src, Tensor & dst) { - return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else { + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + } } // Specializaton for Atom AutoVectorizingCopyAssumedAlignment @@ -379,4 +400,146 @@ copy(Copy_Atom, CA_Args...> const& } #endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +// +// Decay TiledCopy to CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(TiledCopy const& tiled_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + return copy_if(static_cast(tiled_copy), pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(TiledCopy const& tiled_copy, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast(tiled_copy), src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(ThrCopy const& thr_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) = delete; + +template +CUTE_HOST_DEVICE +void +copy(ThrCopy const& thr_copy, + Tensor const& src, + Tensor & dst) = delete; + +// +// Catch uncaught policies +// + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& cpy, + PredTensor const& prd, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& cpy, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& copy_policy, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_policy, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& copy_policy, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_policy, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor && dst) +{ + return copy_aligned(src, dst); +} + } // end namespace cute diff --git a/include/cute/arch/copy.hpp b/include/cute/arch/copy.hpp index 5139289995..47dbef2f55 100644 --- a/include/cute/arch/copy.hpp +++ b/include/cute/arch/copy.hpp @@ -39,7 +39,7 @@ namespace cute { // -// Direct Copy for any type +// Direct Copy for any specific types // template @@ -48,21 +48,15 @@ struct UniversalCopy using SRegisters = S[1]; using DRegisters = D[1]; - template - CUTE_HOST_DEVICE static constexpr void - copy(S_ const& src, - D_ & dst) - { - dst = static_cast(static_cast(src)); - } + // Sanity + static_assert(sizeof_bits_v >= 8); + static_assert(sizeof_bits_v >= 8); - // Accept mutable temporaries - template CUTE_HOST_DEVICE static constexpr void - copy(S_ const& src, - D_ && dst) + copy(S const& src, + D & dst) { - UniversalCopy::copy(src, dst); + dst = src; } }; @@ -92,6 +86,12 @@ using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>; +// +// Copy policy automatically selecting between +// UniversalCopy and cp.async , based on type and memory space. +// +struct AutoCopyAsync {}; + // // Global memory prefetch into L2 // diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp index 60777f2203..17860dd40f 100644 --- a/include/cute/arch/mma_sm80.hpp +++ b/include/cute/arch/mma_sm80.hpp @@ -2040,10 +2040,8 @@ struct SM80_16x8x64_S32U4U4S32_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// - // MMA 8x8x128 TN -struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC +struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC { using DRegisters = uint32_t[2]; using ARegisters = uint32_t[1]; @@ -2056,9 +2054,9 @@ struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( - "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " "{%0, %1}," "{%2}," "{%3}," @@ -2068,7 +2066,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC "r"(b0), "r"(c0), "r"(c1)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); #endif } }; @@ -2076,7 +2074,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x128 TN -struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC +struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC { using DRegisters = uint32_t[4]; using ARegisters = uint32_t[2]; @@ -2089,9 +2087,9 @@ struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( - "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," @@ -2101,7 +2099,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC "r"(b0), "r"(c0), "r"(c1), "r"(c2), "r"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); #endif } }; @@ -2109,7 +2107,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x256 TN -struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC +struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC { using DRegisters = uint32_t[4]; using ARegisters = uint32_t[4]; @@ -2122,9 +2120,9 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) asm volatile( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," @@ -2134,7 +2132,7 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC "r"(b0), "r"(b1), "r"(c0), "r"(c1), "r"(c2), "r"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); #endif } }; @@ -2142,7 +2140,7 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 8x8x128 TN -struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC +struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC { using DRegisters = uint32_t[2]; using ARegisters = uint32_t[1]; @@ -2155,9 +2153,9 @@ struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( - "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1}," "{%2}," "{%3}," @@ -2167,7 +2165,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC "r"(b0), "r"(c0), "r"(c1)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); #endif } }; @@ -2175,7 +2173,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x128 TN -struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC +struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC { using DRegisters = uint32_t[4]; using ARegisters = uint32_t[2]; @@ -2188,9 +2186,9 @@ struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( - "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," @@ -2200,7 +2198,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC "r"(b0), "r"(c0), "r"(c1), "r"(c2), "r"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); #endif } }; @@ -2208,7 +2206,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x256 TN -struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC +struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC { using DRegisters = uint32_t[4]; using ARegisters = uint32_t[4]; @@ -2221,9 +2219,9 @@ struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," @@ -2233,7 +2231,7 @@ struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC "r"(b0), "r"(b1), "r"(c0), "r"(c1), "r"(c2), "r"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); #endif } }; diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index dd6b4e52a0..75b7aa4de6 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -100,16 +100,16 @@ struct Copy_Atom, CopyInternalType> if constexpr (is_constant::value || is_constant::value) { // Dispatch to unpack to execute instruction - return copy_unpack(*this, src, dst); - } else - if constexpr (is_tuple::value && - is_tuple::value) { + return copy_unpack(static_cast(*this), src, dst); + } else if constexpr (is_tuple::value && + is_tuple::value) { // If the size of the src/dst doesn't match the instruction, // recurse this rank-1 layout by peeling off the mode // ((A,B,C,...)) -> (A,B,C,...) return copy(*this, tensor<0>(src), tensor<0>(dst)); } else { - static_assert(dependent_false, "No instruction match and no recursion possible."); + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); } } diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index bfbeb4ea51..ac746a64e1 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -92,23 +92,29 @@ struct Copy_Traits> using RefLayout = SrcLayout; }; +// Extract a CPY_Op from a CPY_Traits +template +struct CPY_Op {}; + +template +struct CPY_Op> { + using type = CPY_Op_Arg; +}; + // // Generic copy_unpack for common argument-based Copy_Traits // -template CUTE_HOST_DEVICE constexpr void -copy_unpack(Copy_Traits const&, - Tensor const& src, - Tensor & dst) +copy_unpack(AnyCPYTraits const&, + Tensor const& src, + Tensor & dst) { - // Specializations can generalize on these checks - //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); - //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); - + using CopyOp = typename CPY_Op::type; using RegistersSrc = typename CopyOp::SRegisters; using RegistersDst = typename CopyOp::DRegisters; using RegTypeSrc = typename remove_extent::type; @@ -129,18 +135,15 @@ copy_unpack(Copy_Traits const&, rD, make_int_sequence{}); } -// // Accept mutable temporaries -// - -template CUTE_HOST_DEVICE constexpr void -copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor && dst) +copy_unpack(AnyCPYTraits const& traits, + Tensor const& src, + Tensor && dst) { copy_unpack(traits, src, dst); } diff --git a/include/cute/atom/copy_traits_sm80.hpp b/include/cute/atom/copy_traits_sm80.hpp index e5ff0b7b35..3795f52a89 100644 --- a/include/cute/atom/copy_traits_sm80.hpp +++ b/include/cute/atom/copy_traits_sm80.hpp @@ -51,13 +51,6 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - - // Construct a zfill variant with a given predicate value - CUTE_HOST_DEVICE constexpr - Copy_Traits> - with(bool pred) const { - return {pred}; - } }; template @@ -73,13 +66,6 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - - // Construct a zfill variant with a given predicate value - CUTE_HOST_DEVICE constexpr - Copy_Traits> - with(bool pred) const { - return {pred}; - } }; template @@ -96,8 +82,15 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - // Predicate value that determines whether to load or zfill - bool pred = false; + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } // Overload copy_unpack for zfill variant to pass the predicate into the op template > // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - // Predicate value that determines whether to load or zfill - bool pred = false; + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } // Overload copy_unpack for zfill variant to pass the predicate into the op template > } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Element copy selector -template -CUTE_HOST_DEVICE constexpr -auto -select_elementwise_copy(SrcTensor const&, DstTensor const&) -{ - using SrcType = typename SrcTensor::value_type; - using DstType = typename DstTensor::value_type; - -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) - if constexpr (is_gmem::value && is_smem::value && - sizeof(SrcType) == sizeof(DstType) && - (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16)) - { - return SM80_CP_ASYNC_CACHEALWAYS{}; - } else { - return UniversalCopy{}; - } - - CUTE_GCC_UNREACHABLE; -#else - return UniversalCopy{}; -#endif -} - -} +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 3738cc3962..4ad7f80851 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -58,37 +58,31 @@ struct AuxTmaParams { }; // Utility for unpacking TMA_LOAD arguments into a CopyOp -template +template struct TMA_LOAD_Unpack { - template CUTE_HOST_DEVICE friend constexpr void copy_unpack(Copy_Traits const& traits, Tensor const& src, Tensor & dst) { + static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); + auto src_coord = src.data().coord_; - if constexpr (detail::is_prefetch) { - return detail::explode_tuple(detail::CallCOPY{}, - traits.opargs_, tuple_seq{}, - src_coord, tuple_seq{}); - } else { - static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); - void* dst_ptr = cute::raw_pointer_cast(dst.data()); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); #if 0 - auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); - printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z, - int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); #endif - return detail::explode_tuple(detail::CallCOPY{}, - traits.opargs_, tuple_seq{}, - make_tuple(dst_ptr), seq<0>{}, - src_coord, tuple_seq{}); - } + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); } }; @@ -131,7 +125,7 @@ struct Copy_Traits [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent - return {{}, {&tma_desc_, &tma_mbar, static_cast(cache_hint)}}; + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; } // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) @@ -143,7 +137,7 @@ struct Copy_Traits [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent - return {{}, {new_tma_desc, &tma_mbar, static_cast(cache_hint)}}; + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; } // Generate the TMA coord tensor @@ -167,7 +161,7 @@ struct Copy_Traits // The executable SM90_TMA_LOAD with tma_desc and tma_mbar template struct Copy_Traits - : TMA_LOAD_Unpack + : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -183,12 +177,15 @@ struct Copy_Traits uint64_t*, // smem mbarrier uint64_t // cache hint > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} }; // The prefetch for SM90_TMA_LOAD with tma_desc template struct Copy_Traits - : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -206,6 +203,19 @@ struct Copy_Traits CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const& traits) : opargs_({&traits.tma_desc_}) {} + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto src_coord = src.data().coord_; + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -246,7 +256,7 @@ struct Copy_Traits uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { - return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + return {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; } // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) @@ -257,7 +267,7 @@ struct Copy_Traits uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { - return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + return {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; } // Generate the TMA coord tensor @@ -281,7 +291,7 @@ struct Copy_Traits // The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask template struct Copy_Traits - : TMA_LOAD_Unpack + : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -298,43 +308,17 @@ struct Copy_Traits uint16_t, // multicast mask uint64_t // cache hint > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) + : opargs_(desc, mbar, mask, hint) {} }; ////////////////////////////////////////////////////////////////////////////// ///////////////////////////// TMA_STORE ////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// -// Utility for unpacking TMA_STORE arguments into a CopyOp -template -struct TMA_STORE_Unpack -{ - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); - - void const* const desc_ptr = traits.tma_desc_; - void const* const src_ptr = cute::raw_pointer_cast(src.data()); - auto dst_coord = dst.data().coord_; -#if 0 - auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); - printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z, - int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); -#endif - return detail::explode_tuple(detail::CallCOPY{}, - make_tuple(desc_ptr, src_ptr), seq<0,1>{}, - dst_coord, tuple_seq{}); - } -}; - -struct SM90_TMA_STORE_OP : SM90_TMA_STORE {}; +struct SM90_TMA_STORE_PTR : SM90_TMA_STORE {}; // The executable SM90_TMA_STORE with tma_desc template @@ -369,6 +353,13 @@ struct Copy_Traits return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); } + // Construct new TMA_STORE with (unsafe) swapped out TMA descriptor ptr (for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + template CUTE_HOST_DEVICE friend constexpr void @@ -393,19 +384,11 @@ struct Copy_Traits make_tuple(desc_ptr, src_ptr), seq<0,1>{}, dst_coord, tuple_seq{}); } - - // Construct Copy_Traits executable (w/ swapped out TMA descriptor) for SM90_TMA_STORE (for grouped gemm/ptr array gemm) - CUTE_HOST_DEVICE constexpr - Copy_Traits - with(TmaDescriptor const* new_tma_desc) const { - return {{}, new_tma_desc}; - } }; -// The executable SM90_TMA_STORE with tma_desc +// Same as SM90_TMA_STORE, but with an unsafe TMA Desc PTR instead template -struct Copy_Traits - : TMA_STORE_Unpack +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -417,6 +400,31 @@ struct Copy_Traits // SM90_TMA_STORE arguments TmaDescriptor const* tma_desc_; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = traits.tma_desc_; + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -520,7 +528,7 @@ struct Copy_Traits CUTE_HOST_DEVICE constexpr Copy_Traits with(uint64_t& bulk_mbar) const { - return {{&bulk_mbar}}; + return {&bulk_mbar}; } template CUTE_HOST_DEVICE constexpr Copy_Traits with(uint64_t& bulk_mbar) const { - return {{&bulk_mbar}}; + return {&bulk_mbar}; } }; @@ -1391,19 +1399,46 @@ tma_partition(Copy_Atom const& copy_atom, return cute::make_tuple(gresult, sresult); } +// Explicit defaults for cta_coord and cta_layout +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + return tma_partition(copy_atom, Int<0>{}, Layout<_1,_0>{}, stensor, gtensor); +} + // TMA Multicast Masks Calculation template CUTE_HOST_DEVICE constexpr -auto +uint16_t create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, CtaCoord const& cta_coord_vmnk) { auto cta_coord_slicer = replace(cta_coord_vmnk, _); auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk); - // Get the instruction code + uint16_t mcast_mask = 0; - for (int i = 0; i < size(cta_layout); ++i) { - mcast_mask |= uint16_t(1) << cta_layout(i); + if constexpr (rank_v == 1 and depth_v <= 1 and + not is_static::value) { + // Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout + mcast_mask = uint16_t(1); + // Smear by stride<0> (may want to predicate on stride<0> mag?) + mcast_mask |= mcast_mask << (1*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (2*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (4*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (8*stride<0>(cta_layout)); + // Select shape<0> + mcast_mask &= (uint16_t(-1) >> (16 - shape<0>(cta_layout) * stride<0>(cta_layout))); + } else { + // Get the instruction code -- generic path + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); + } } // Shift by the instruction's elected block rank (dynamic) mcast_mask <<= elected_cta; diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index bf40827436..7cb4fe3df2 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -250,12 +250,12 @@ struct TiledMMA : MMA_Atom auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) // Tile the tensor for the Atom - auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), make_layout(size<1>(AtomShape_MNK{}))); - auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN)) + auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) // Transform the Atom mode from (M,K) to (Thr,Val) - auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) // Tile the tensor for the C-threads auto thr_tile = make_tile(_, @@ -604,16 +604,15 @@ CUTE_HOST_DEVICE constexpr auto partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) { - constexpr int R = rank_v; - static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - auto V = shape<1>(typename TiledMMA::AtomLayoutC_TV{}); - auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK)); - auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK)); - return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN)); + auto dummy = make_layout(shape(shape_MN)); + auto dummy_tv = mma.thrfrg_C(dummy); + // Slice+rearrange like partition_C + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + } + template CUTE_HOST_DEVICE constexpr auto @@ -632,14 +631,12 @@ CUTE_HOST_DEVICE constexpr auto partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) { - constexpr int R = rank_v; - static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - auto V = shape<1>(typename TiledMMA::AtomLayoutA_TV{}); - auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK)); - auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK)); - return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK)); + auto dummy = make_layout(shape(shape_MK)); + auto dummy_tv = mma.thrfrg_A(dummy); + // Slice+rearrange like partition_A + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + } template @@ -647,14 +644,12 @@ CUTE_HOST_DEVICE constexpr auto partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) { - constexpr int R = rank_v; - static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - auto V = shape<1>(typename TiledMMA::AtomLayoutB_TV{}); - auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK)); - auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK)); - return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK)); + auto dummy = make_layout(shape(shape_NK)); + auto dummy_tv = mma.thrfrg_B(dummy); + // Slice+rearrange like partition_B + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + } // diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp index 706b10d889..5f7e73e467 100644 --- a/include/cute/atom/mma_traits_sm80.hpp +++ b/include/cute/atom/mma_traits_sm80.hpp @@ -419,6 +419,203 @@ template <> struct MMA_Traits : MMA_Traits {}; +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V8) -> (M8,N32) + using ALayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V16) -> (M16,N32) + using ALayout = Layout, Shape < _8, _2>>, + Stride, Stride<_16, _8>>>; + // (T32,V8) -> (M8,N32) + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,N64) + using ALayout = Layout, Shape < _8, _2, _2>>, + Stride, Stride<_16, _8, _512>>>; + // (T32,V16) -> (M8,N64) + using BLayout = Layout, Shape <_8, _2>>, + Stride, Stride<_8, _256>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + /////////////////////////////////////////////////////////////////////////////// /////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -440,9 +637,13 @@ struct MMA_Traits using CLayout = SM80_16x8_Row; }; +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 & b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + template <> struct MMA_Traits - :MMA_Traits {}; + : MMA_Traits {}; template<> struct MMA_Traits @@ -455,7 +656,7 @@ struct MMA_Traits using Shape_MNK = Shape<_8,_8,_128>; using ThrID = Layout<_32>; using ALayout = Layout,_32>, - Stride,_8>>; + Stride,_8>>; using BLayout = Layout,_32>, Stride,_8>>; using CLayout = SM80_8x8_Row; @@ -472,7 +673,7 @@ struct MMA_Traits using ValTypeA = cute::uint1b_t; using ValTypeB = cute::uint1b_t; using ValTypeC = int32_t; - + using Shape_MNK = Shape<_16,_8,_128>; using ThrID = Layout<_32>; using ALayout = Layout,Shape<_32,_2>>, diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index b02f5b3afd..8f59ff55b4 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -1128,7 +1128,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp index 27c41ad338..161dc7ecf0 100644 --- a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -7735,4 +7735,4 @@ struct MMA_Traits +# include #endif // _MSC_VER #if defined(__CUDACC_RTC__) diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 57db56aba5..48d416f45b 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -100,20 +100,30 @@ struct subbyte_reference // Copy Ctor CUTE_HOST_DEVICE constexpr - subbyte_reference(subbyte_reference const& other) { - *this = element_type(other); + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); } // Copy Assignment CUTE_HOST_DEVICE constexpr - subbyte_reference& operator=(subbyte_reference const& other) { - return *this = element_type(other); + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); } // Assignment template CUTE_HOST_DEVICE constexpr - enable_if_t, subbyte_reference&> operator=(element_type x) + enable_if_t, subbyte_reference&> operator=(value_type x) { static_assert(is_same_v, "Do not specify template arguments!"); storage_type item = (reinterpret_cast(x) & BitMask); @@ -149,11 +159,11 @@ struct subbyte_reference // Value CUTE_HOST_DEVICE - element_type get() const + value_type get() const { if constexpr (is_same_v) { // Extract to bool -- potentially faster impl return bool((*ptr_) & (BitMask << idx_)); - } else { // Extract to element_type + } else { // Extract to value_type // Extract from the current storage element auto item = storage_type((ptr_[0] >> idx_) & BitMask); @@ -165,13 +175,13 @@ struct subbyte_reference item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); } - return reinterpret_cast(item); + return reinterpret_cast(item); } } - // Extract to type element_type + // Extract to type value_type CUTE_HOST_DEVICE constexpr - operator element_type() const { + operator value_type() const { return get(); } @@ -341,6 +351,14 @@ recast_ptr(subbyte_iterator const& x) { CUTE_GCC_UNREACHABLE; } +// Dynamic pointers have unknown static alignment +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(subbyte_iterator const& x) { + return {}; +} + template CUTE_HOST_DEVICE void print(subbyte_iterator const& x) { @@ -352,6 +370,7 @@ CUTE_HOST_DEVICE void print(subbyte_reference const& x) { print(x.get()); } + // // array_subbyte // Statically sized array for non-byte-aligned data types diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index bc1b54efbc..26195a4782 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1830,7 +1830,7 @@ recast_layout(Layout const& layout) return upcast(layout); } else { - static_assert(dependent_false, "Recast not supported."); + return downcast(upcast(layout)); } CUTE_GCC_UNREACHABLE; diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 3e5f836279..26ae8dc76c 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -616,7 +616,7 @@ recast_layout(ComposedLayout const& layout) return upcast(layout); } else { - static_assert(dependent_false, "Recast not supported."); + return downcast(upcast(layout)); } CUTE_GCC_UNREACHABLE; } @@ -631,6 +631,15 @@ max_alignment(ComposedLayout const& layout) return Int<1>{}; } +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + return Layout<_1,_0>{}; +} + // // Display utilities // diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 1b1432533a..a614bdb2d9 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -154,13 +154,6 @@ operator*(C, R) { return {}; } -template -CUTE_HOST_DEVICE constexpr -typename R::type -operator/(C, R) { - return {}; -} - // Product with dynamic type needs to produce an integer... template ::value)> @@ -179,6 +172,13 @@ operator*(R, C const& c) { return c * R::num / R::den; } +template +CUTE_HOST_DEVICE constexpr +auto +operator/(C const& c, R) { + return c * R{}; +} + template CUTE_HOST_DEVICE constexpr typename R::type @@ -200,6 +200,10 @@ operator+(C, R) { return {}; } +///////////////// +// Comparisons // +///////////////// + template CUTE_HOST_DEVICE constexpr bool_constant::num == R::num && R::den == R::den> @@ -221,6 +225,31 @@ operator==(C, R) { return {}; } +/////////////////////// +// Special functions // +/////////////////////// + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(C, R) { + return {}; +} + template CUTE_HOST_DEVICE constexpr typename R::type diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp index 07444331ff..b9943b8ca3 100644 --- a/include/cute/numeric/numeric_types.hpp +++ b/include/cute/numeric/numeric_types.hpp @@ -46,6 +46,7 @@ template static constexpr auto sizeof_bits_v = sizeof_bits::value; using cutlass::bits_to_bytes; +using cutlass::bytes_to_bits; using cutlass::is_subbyte; diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 4cfa129cce..cc49b6a356 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -214,6 +214,14 @@ make_smem_ptr(void const* ptr) { return make_smem_ptr(recast_ptr(ptr)); } +// nullptr_t overload for make_smem_ptr(nullptr) disambiguation +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(decltype(nullptr)) { // nullptr_t + return make_smem_ptr(recast_ptr(nullptr)); +} + // The smem tag is invariant over type-recast template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index 90ca0ceb6e..57ad0b3cde 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -30,9 +30,10 @@ **************************************************************************************************/ #pragma once -#include // CUTE_HOST_DEVICE -#include // cute::sizeof_bits -#include // cute::declval, cute::void_t, etc +#include // CUTE_HOST_DEVICE +#include // cute::sizeof_bits +#include // Int<0> +#include // cute::declval, cute::void_t, etc namespace cute { @@ -115,6 +116,14 @@ raw_pointer_cast(T* ptr) { return ptr; } +// The statically-known alignment of a dynamic pointer is unknown +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(T*) { + return {}; +} + // // A very simplified iterator adaptor. // Derived classed may override methods, but be careful to reproduce interfaces exactly. @@ -169,6 +178,13 @@ raw_pointer_cast(iter_adaptor const& x) { return raw_pointer_cast(x.ptr_); } +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(iter_adaptor const& x) { + return max_alignment(x.ptr_); +} + // // counting iterator -- quick and dirty // diff --git a/include/cute/pointer_swizzle.hpp b/include/cute/pointer_swizzle.hpp index 720b9b1246..1a802cfdc6 100644 --- a/include/cute/pointer_swizzle.hpp +++ b/include/cute/pointer_swizzle.hpp @@ -147,6 +147,14 @@ recast_ptr(swizzle_ptr const& ptr) { return make_swizzle_ptr(recast_ptr(ptr.get()), SwizzleFn{}); } +// The statically-known alignment of a swizzle pointer is the alignment of the swizzle function converted to bits +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(swizzle_ptr const&) { + return Int<8>{} * max_alignment(SwizzleFn{}); +} + // // Display utilities // diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 1324360eba..7f7161bc32 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -447,7 +447,7 @@ recast_layout(Swizzle const& swizzle) return upcast(swizzle); } else { - static_assert(dependent_false, "Recast not supported."); + return downcast(upcast(layout)); } CUTE_GCC_UNREACHABLE; } @@ -457,7 +457,7 @@ CUTE_HOST_DEVICE constexpr auto max_alignment(Swizzle const&) { - return Int<1 << M>{}; + return Int<(1 << M)>{}; } template diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index 3564c667b1..2be19c15e3 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -84,6 +84,8 @@ struct ArrayEngine }; // Specialization for sparse_elem tensor allocation/iteration +// NOTE: This can and should be used for allocation of SMEM as well! +// Fuse these two ArrayEngines? template struct ArrayEngine, N> { @@ -858,6 +860,17 @@ max_common_layout(Tensor const& a, CUTE_GCC_UNREACHABLE; } +/* Return the maximum (statically known) alignment of a Tensor in the number of bits + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Tensor const& t) +{ + return gcd(max_alignment(t.data()), + max_alignment(t.layout()) * static_value>()); +} + // // Key algebraic operations -- Composition, Divide, and Product // diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index 86da7cae91..2645444369 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -123,7 +123,7 @@ bool block([[maybe_unused]] int bid) { #if defined(__CUDA_ARCH__) - return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; + return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == static_cast(bid); #else return true; #endif @@ -134,7 +134,7 @@ bool thread([[maybe_unused]] int tid, [[maybe_unused]] int bid) { #if defined(__CUDA_ARCH__) - return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); + return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == static_cast(tid)) && block(bid); #else return true; #endif diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index e663b569c6..a3074ef947 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -141,9 +141,15 @@ using CUTE_STL_NAMESPACE::common_type_t; using CUTE_STL_NAMESPACE::remove_pointer; using CUTE_STL_NAMESPACE::remove_pointer_t; +using CUTE_STL_NAMESPACE::add_pointer; +using CUTE_STL_NAMESPACE::add_pointer_t; + using CUTE_STL_NAMESPACE::alignment_of; using CUTE_STL_NAMESPACE::alignment_of_v; +using CUTE_STL_NAMESPACE::is_pointer; +using CUTE_STL_NAMESPACE::is_pointer_v; + // using CUTE_STL_NAMESPACE::declval; diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index c96897324a..460531aa89 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -47,6 +47,99 @@ namespace cutlass { namespace arch { //////////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_DEVICE void fence_view_async_shared(); + +namespace detail { // namespace detail begin + +// Single threaded versions that need to be called in an elect_one region +template +CUTLASS_DEVICE +void initialize_barrier_array(T ptr, int arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + ptr[i].init(arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array(uint64_t *ptr, int arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + T::init(&ptr[i], arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + full_barriers[i].init(full_barrier_arv_cnt); + empty_barriers[i].init(empty_barrier_arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); + EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); + } +} + +// Aligned versions that need to be call warp wide +template +CUTLASS_DEVICE +void initialize_barrier_array_aligned(T ptr, int arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + ptr[i].init(arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_aligned(uint64_t *ptr, int arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + T::init(&ptr[i], arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair_aligned(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + full_barriers[i].init(full_barrier_arv_cnt); + empty_barriers[i].init(empty_barrier_arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair_aligned(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); + EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); + } + } +} + +} // namespace detail end + + // Enumerates the reserved named barriers to avoid potential conflicts // This enum class specifies the NamedBarriers reserved by CUTLASS. enum class ReservedNamedBarriers { diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index b0f750063c..0fc60f41db 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -35,6 +35,8 @@ #pragma once +#include "cutlass/platform/platform.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// // SM90 @@ -79,3 +81,5 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index 6b487a7377..0e957c72ae 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -35,6 +35,7 @@ #pragma once #include "cutlass/array.h" +#include "cutlass/detail/helper_macros.hpp" #include "cutlass/layout/matrix.h" #include "cute/arch/copy_sm75.hpp" #include "cute/arch/util.hpp" @@ -50,7 +51,7 @@ template < /// .x1, .x2, or .x4 int MatrixCount > -inline __device__ void ldsm(Array & D, void const* ptr); +CUTLASS_DEVICE void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -59,19 +60,19 @@ inline __device__ void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// /// CUTLASS helper to get SMEM pointer -inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { +CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { return cute::cast_smem_ptr_to_uint(ptr); } /// CUTLASS helper to get SMEM pointer -inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { +CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void const *ptr) { return cutlass_get_smem_pointer(const_cast(ptr)); } ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -95,7 +96,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -119,7 +120,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -147,7 +148,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -171,7 +172,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -195,7 +196,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h index 6471de8a87..28bb46382c 100644 --- a/include/cutlass/arch/mma_sm70.h +++ b/include/cutlass/arch/mma_sm70.h @@ -33,11 +33,7 @@ */ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 6cced190e8..a39ededbe0 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/arch/wmma.h" diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index f990c1ac27..19d78bf20e 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "mma.h" diff --git a/include/cutlass/arch/mma_sm89.h b/include/cutlass/arch/mma_sm89.h index fe4b7eb7e6..d8a75b6623 100644 --- a/include/cutlass/arch/mma_sm89.h +++ b/include/cutlass/arch/mma_sm89.h @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "mma.h" diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index 1183ee5e05..16108f0a1e 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sparse_sm80.h b/include/cutlass/arch/mma_sparse_sm80.h index 7041d04dd4..ed2a5ad019 100644 --- a/include/cutlass/arch/mma_sparse_sm80.h +++ b/include/cutlass/arch/mma_sparse_sm80.h @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sparse_sm89.h b/include/cutlass/arch/mma_sparse_sm89.h index c092df7682..2fae35be42 100644 --- a/include/cutlass/arch/mma_sparse_sm89.h +++ b/include/cutlass/arch/mma_sparse_sm89.h @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/simd.h b/include/cutlass/arch/simd.h index 3104746e58..f670fc293f 100644 --- a/include/cutlass/arch/simd.h +++ b/include/cutlass/arch/simd.h @@ -34,8 +34,8 @@ #pragma once -#include "../array.h" -#include "../numeric_types.h" +#include "cutlass/arch/array.h" +#include "cutlass/arch/numeric_types.h" namespace cutlass { namespace arch { diff --git a/include/cutlass/arch/synclog.hpp b/include/cutlass/arch/synclog.hpp index ea683859a3..8cf65ad73e 100644 --- a/include/cutlass/arch/synclog.hpp +++ b/include/cutlass/arch/synclog.hpp @@ -59,7 +59,7 @@ constexpr uint32_t synclog_cap = 1 << 26; inline std::mutex synclog_mutex; inline std::vector synclog_buf_list; #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -inline __device__ uint32_t* synclog_buf; +CUTLASS_DEVICE uint32_t* synclog_buf; #endif CUTLASS_DEVICE diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index 19fda4f85d..d75ee2b075 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index 4a2689058b..b644181b80 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index 4663e95c78..f603605128 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 62e9469497..e85d19facf 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -2573,20 +2573,8 @@ Array fma(Array const &a, Array const &b, T c) { return op(a, b, c); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -} // namespace cutlass - //////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/array_subbyte.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { + //////////////////////////////////////////////////////////////////////////////////////////////////// // AlignedArray @@ -2606,9 +2594,10 @@ class alignas(Alignment) AlignedArray: public Array { }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/array_subbyte.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index eb77a9310e..d2e0e5efdb 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -554,6 +554,8 @@ struct Array { //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/blas3.h b/include/cutlass/blas3.h index ee5587d1cc..d41f1ee61e 100644 --- a/include/cutlass/blas3.h +++ b/include/cutlass/blas3.h @@ -132,7 +132,7 @@ struct MantissaInBits { template <> struct MantissaInBits> { static int constexpr bits = 30; - static double constexpr error = 1.0e-15; + static double constexpr error = 1.0e-14; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 78862b0a09..0e5d898d0e 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -189,7 +189,7 @@ struct CollectiveConv< -problem_shape.dilation[NumSpatialDimensions-1-i] : problem_shape.dilation[NumSpatialDimensions-1-i]; } - + return make_im2col_tma_copy( GmemTiledCopyA{}, tensor_a, @@ -225,7 +225,7 @@ struct CollectiveConv< auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); auto lower_srt = detail::compute_lower_srt(problem_shape); - + return make_im2col_tma_copy( GmemTiledCopyB{}, tensor_b, @@ -372,6 +372,96 @@ struct CollectiveConv< return false; } + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::ostringstream os; +#endif + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (input_stride[i] != input_shape_size) { + os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n input_shape_size: " << input_shape_size + << "\n input_shape: " << input_shape + << "\n input_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (output_stride[i] != output_shape_size) { + os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n output_shape_size: " << input_shape_size + << "\n output_shape: " << input_shape + << "\n output_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + } + + // Conv kernels only support cross correlation mode currently. + implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + if (problem_shape.groups > 1) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); return false; @@ -516,9 +606,9 @@ struct CollectiveConv< // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_producer_state); @@ -645,7 +735,7 @@ struct CollectiveConv< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/conv/convnd_problem_shape.hpp b/include/cutlass/conv/convnd_problem_shape.hpp index ffcc547fbd..cd2f674ff4 100644 --- a/include/cutlass/conv/convnd_problem_shape.hpp +++ b/include/cutlass/conv/convnd_problem_shape.hpp @@ -319,6 +319,7 @@ struct ConvProblemShape { // | ShapeB | KTRSC | KTRSC | NDHWC | // | ShapeC | NZPQK | NDHWC | KTRSC | // + // Input comes from calculate_xformed_act, which does NOT depend on ConvOp. CUTLASS_HOST_DEVICE constexpr void set_shape_stride_ABC( @@ -328,6 +329,31 @@ struct ConvProblemShape { TensorStride stride_flt, TensorExtent shape_xformed_act, TensorStride stride_xformed_act) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("*** set_shape_stride_ABC ***"); + printf("\n shape_act: "); + print(shape_act); + printf("\n stride_act: "); + print(stride_act); + printf("\n shape_flt: "); + print(shape_flt); + printf("\n stride_flt: "); + print(stride_flt); + printf("\n shape_xformed_act: "); + print(shape_xformed_act); + printf("\n stride_xformed_act: "); + print(stride_xformed_act); + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + printf("\n ConvOp: Fprop"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + printf("\n ConvOp: Dgrad"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + printf("\n ConvOp: Wgrad"); + } + printf("\n"); +#endif if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { shape_A = shape_act; @@ -353,6 +379,20 @@ struct ConvProblemShape { shape_C = shape_flt; stride_C = stride_flt; } +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n shape_A: "); + print(shape_A); + printf("\n stride_A: "); + print(stride_A); + printf("\n shape_B: "); + print(shape_B); + printf("\n stride_B: "); + print(stride_B); + printf("\n shape_C: "); + print(shape_C); + printf("\n stride_C: "); + print(stride_C); +#endif } // Get A extents. diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h index 5e4299564f..d4e98fa49e 100644 --- a/include/cutlass/conv/kernel/direct_convolution.h +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -40,6 +40,7 @@ #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" #include "cutlass/semaphore.h" #include "cutlass/tensor_ref.h" #include "cutlass/layout/tensor.h" @@ -155,7 +156,7 @@ struct DirectConvolutionParams { swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); // Dynamic SMEM usage because stride and dilation are runtime params. - smem_size_ = (max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); + smem_size_ = (cutlass::platform::max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index d778046c2f..fe884d7037 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -37,7 +37,7 @@ #if defined(__CUDACC_RTC__) #include #else -#include +#include #endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index 1c8f56a652..2adfd2665f 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -85,7 +85,11 @@ namespace cutlass { #if !defined(__CUDACC_RTC__) +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) #include +#endif // (__CUDACC_VERSION__ >= 11.8) + #include #define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok @@ -100,7 +104,8 @@ namespace cutlass { #else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) -#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) +#if ((__CUDACC_VER_MAJOR__ >= 13) || \ + ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) \ #define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ template \ @@ -138,7 +143,7 @@ namespace cutlass { return reinterpret_cast(pfn)(args...); \ } -#endif // (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) +#endif // (__CUDACC_VERSION__ >= 12.5) #endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp index a4b288e7c9..9d8f9e2f1d 100644 --- a/include/cutlass/detail/collective.hpp +++ b/include/cutlass/detail/collective.hpp @@ -31,6 +31,7 @@ #pragma once #include "cute/container/tuple.hpp" +#include "cute/layout.hpp" // cute::size(shape) ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index f9f348b9f7..c175538efa 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -754,7 +754,6 @@ struct MixedInputUtils { cute::tuple& partitioned_extra_info, int const k_block) { - static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); static_assert(cosize_v == cosize_v); diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 4cd895f147..039f5e841a 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -57,6 +57,12 @@ #define CUTLASS_DEVICE inline #endif +#if ! defined(_MSC_VER) +#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline)) +#else +#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]] +#endif + #define CUTLASS_HOST __host__ #define CUTLASS_GLOBAL __global__ static @@ -74,11 +80,11 @@ CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) #ifdef _MSC_VER // Provides support for alternative operators 'and', 'or', and 'not' -#include +#include #endif // _MSC_VER #if !defined(__CUDACC_RTC__) -#include +#include #endif #if defined(__CUDA_ARCH__) diff --git a/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp b/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp new file mode 100644 index 0000000000..914443dd0d --- /dev/null +++ b/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Mainloop Fusion configs specific for scale factors +*/ + +#pragma once + +#include // cute::void_t + +namespace cutlass::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ElementSFType { + using type = void; +}; + +template +struct ElementSFType> { + using type = typename CollectiveMainloop::ElementSF; +}; + +template +struct LayoutSFAType { + using type = void; +}; + +template +struct LayoutSFAType> { + using type = typename CollectiveMainloop::LayoutSFA; +}; + +template +struct LayoutSFBType { + using type = void; +}; + +template +struct LayoutSFBType> { + using type = typename CollectiveMainloop::LayoutSFB; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 7af5d96cf6..cc7caede49 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -34,8 +34,11 @@ #pragma once +#include // CUTLASS_HOST_DEVICE +#include // uint64_t + // __grid_constant__ was introduced in CUDA 11.7. -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) && !CUTLASS_CLANG_CUDA # define CUTLASS_GRID_CONSTANT_SUPPORTED #endif diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 759591b5dc..720dcc008a 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -422,7 +422,8 @@ struct CollectiveBuilder< Schedule, fusion::LinearCombination, cute::enable_if_t || - cute::is_same_v >> { + cute::is_same_v || + cute::is_same_v >> { // Passing void C disables source load using ElementC = cute::conditional_t, diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp index 0f6f329311..da7562b43a 100644 --- a/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -86,7 +86,7 @@ class DefaultEpilogueArray { static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -198,20 +198,30 @@ class DefaultEpilogueArray { assert(0); } - InternalStrideC stride_c; - InternalStrideD stride_d; - if constexpr (!cute::is_same_v) { - // If grouped gemm - if (epilogue_op.is_source_needed()) { - stride_c = detail::get_epilogue_stride(params.dC[l_coord]); + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); } - stride_d = detail::get_epilogue_stride(params.dD[l_coord]); - } - else { - stride_c = detail::get_epilogue_stride(params.dC); - stride_d = detail::get_epilogue_stride(params.dD); - } - + }(); + // Represent the full output tensor ElementC const* ptr_C_l = nullptr; if (epilogue_op.is_source_needed()) { diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 6c0368e09b..23e57d99b8 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -157,7 +157,8 @@ struct EmptyStorage { template CUTLASS_HOST_DEVICE auto get_epilogue_stride(Stride stride){ - if constexpr (cute::is_base_of_v) { + if constexpr (cute::is_base_of_v|| + cute::is_base_of_v) { return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); } else { @@ -464,7 +465,7 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } }; -// SFINAE helpers for detecting beta/beta_ptr in EVT arguments. +// SFINAE helpers for detecting beta/beta_ptr/beta_ptr_array in EVT arguments. template struct has_beta { static constexpr bool value = false; @@ -485,6 +486,16 @@ struct has_beta_ptr +struct has_beta_ptr_array { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr_array> { + static constexpr bool value = true; +}; + } // namespace detail } // namespace collective } // namespace epilogue diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 84b6e14eeb..54fe9b1daf 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -328,7 +328,7 @@ class CollectiveEpilogue< } uint32_t transaction_bytes = TmaTransactionBytes; - typename Params::TMA_C tma_load_c = {}; + typename Params::TMA_C tma_load_c{}; if constexpr (is_source_supported) { ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); @@ -409,7 +409,7 @@ class CollectiveEpilogue< implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); } - if constexpr (not cute::is_void_v) { + if constexpr (is_source_supported) { constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); @@ -432,13 +432,16 @@ class CollectiveEpilogue< bool beta_implementable = true; - if constexpr (cute::is_void_v) { + if (cute::is_void_v || args.ptr_C == nullptr) { if constexpr (detail::has_beta::value) { beta_implementable = args.thread.beta == 0.0; } if constexpr (detail::has_beta_ptr::value) { beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; } + if constexpr (detail::has_beta_ptr_array::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; + } } if (!beta_implementable) { @@ -775,7 +778,7 @@ class CollectiveEpilogue< tRS_rC, thread_idx }; - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); @@ -1017,7 +1020,7 @@ class CollectiveEpilogue< Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) if constexpr (IsLoad) { - if (not cute::is_void_v) { + if (is_source_supported) { constexpr int C_tensormap_index = NumEpilogueWarpGroups; Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); @@ -1058,8 +1061,10 @@ class CollectiveEpilogue< // Replacing global_address for the next batch if constexpr (IsLoad) { if constexpr (is_source_supported) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, - params.ptr_C[next_batch]); + if (params.ptr_C != nullptr) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, + params.ptr_C[next_batch]); + } } } else if constexpr (is_destination_supported) { @@ -1087,18 +1092,20 @@ class CollectiveEpilogue< if constexpr (IsLoad) { if constexpr (is_source_supported) { - ElementC const* ptr_C = nullptr; - Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; + if (params.dC != nullptr) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, - prob_shape, - prob_stride); } } else if constexpr (is_destination_supported) { @@ -1166,7 +1173,7 @@ class CollectiveEpilogue< void tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { if constexpr (IsLoad) { - if constexpr (not cute::is_void_v) { + if constexpr (is_source_supported) { cute::tma_descriptor_fence_acquire(tensormap); } } diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index b96c4aea00..b3c7bf387d 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -94,7 +94,7 @@ class CollectiveEpilogue< SmemLayoutAtomD_, CopyOpR2S_, CopyAtomC_, - CopyOpR2R_, + CopyOpR2R_ > { public: // @@ -136,6 +136,9 @@ class CollectiveEpilogue< static_assert(not cute::is_void_v, "SmemElementD is void"); using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; + using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; @@ -239,14 +242,14 @@ class CollectiveEpilogue< struct Params { using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}), take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{})); using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD{}), take<0,2>(SmemLayoutD{}), EpilogueTile{}, @@ -273,9 +276,9 @@ class CollectiveEpilogue< auto [M, N, K, L] = problem_shape_MNKL; uint32_t transaction_bytes = TmaTransactionBytes; - typename Params::TMA_C tma_load_c = {}; + typename Params::TMA_C tma_load_c{}; if constexpr (is_source_supported) { - Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); tma_load_c = make_tma_copy_C_sm90( CopyOpG2S{}, tensor_c, @@ -285,7 +288,7 @@ class CollectiveEpilogue< typename Params::TMA_D tma_store_d; if constexpr (is_destination_supported) { - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); tma_store_d = make_tma_copy_C_sm90( CopyOpS2G{}, tensor_d, @@ -644,7 +647,18 @@ class CollectiveEpilogue< // Absolute coordinate tensors (dynamic) Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + Tensor tRS_cD_mn = [&]() { + if constexpr (IsUseR2R) { + // (t)hread-partition for ConsumerStoreCallbacks. + TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); + + return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + else { + return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + }(); // Relative coordinate tensors (static) Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index f829a2ff5d..a5f47f0832 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -50,6 +50,7 @@ struct EpilogueSimtVectorized {}; struct EpiloguePtrArraySimtVectorized {}; struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecializedTransposed {}; struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 3aed32710f..1ef06a538b 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -34,6 +34,7 @@ #include #include #include +#include // cute::false_type ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,9 +61,12 @@ struct FusionOperation { static constexpr int AlignmentScalar = 0; static constexpr bool IsScaleFactorSupported = false; static constexpr bool IsPerRowScaleSupported = false; + static constexpr bool IsPerColScaleSupported = false; + using ElementBias = void; static constexpr int AlignmentBias = 0; static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsPerColBiasSupported = false; static constexpr bool IsDePerRowBiasSupported = false; using ActivationFn = void; @@ -190,6 +194,24 @@ struct LinCombPerRowBiasEltAct static constexpr bool IsEltActSupported = true; }; +// D = activation(alpha * acc + beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltAct + : LinCombPerColBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + // D = activation(alpha * acc + beta * C + per-row bias) // aux = alpha * acc + beta * C + per-row bias template< @@ -214,6 +236,30 @@ struct LinCombPerRowBiasEltActAux static constexpr bool IsAuxOutSupported = true; }; +// D = activation(alpha * acc + beta * C + per-col bias) +// aux = alpha * acc + beta * C + per-col bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltActAux + : LinCombPerColBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + // D = activation(per-row alpha * acc + per-row beta * C + per-row bias) template< template class ActivationFn_, @@ -233,6 +279,46 @@ struct PerRowLinCombPerRowBiasEltAct static constexpr bool IsPerRowScaleSupported = true; }; +// D = per-column alpha * per-row alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementCompute_, + class ElementScalar_ = ElementCompute_, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct OuterProdLinComb : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr auto RoundStyle = RoundStyle_; + static constexpr bool IsSourceSupported = true; + static constexpr bool IsPerRowScaleSupported = true; + static constexpr bool IsPerColScaleSupported = true; +}; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerColScaleSupported = true; +}; + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias // if D is fp8 // D = scale_d * activation(Z) @@ -254,6 +340,27 @@ struct ScaledLinCombPerRowBiasEltAct static constexpr bool IsScaleFactorSupported = true; }; +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias // if D is fp8 // amax_d = max(abs(elements in activation(Z))) @@ -291,6 +398,43 @@ struct ScaledLinCombPerRowBiasEltActAmaxAux static constexpr bool IsAuxOutSupported = true; }; +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltActAmaxAux + : ScaledLinCombPerColBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + // Z = Aux // dY = alpha * acc + beta * C // D = d_activation(dY, Z) diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index e028846a4f..3e57fa0ba6 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -708,6 +708,105 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(alpha * acc + beta * C + per-column bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltAct = + Sm90EVT, + Sm90LinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = activation(alpha * acc + beta * C + per-row bias) // Aux = alpha * acc + beta * C + per-row bias) template< @@ -832,6 +931,132 @@ struct FusionCallbacks< }; ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(alpha * acc + beta * C + per_col bias) +// Aux = alpha * acc + beta * C + per_col bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerColBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = per-row alpha * acc + per-row beta * C + per-row bias template< class CtaTileShapeMNK, @@ -954,53 +1179,36 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace detail { - -template -constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; - -// We only apply the scaling factor if output is fp8 -template -struct ScaleOutOp { template using Op = cutlass::first; }; -template <> -struct ScaleOutOp { template using Op = cutlass::multiplies; }; -template <> -struct ScaleOutOp { template using Op = cutlass::multiplies; }; - -template -using amax = cutlass::maximum_absolute_value_reduction; // propogate nans - -}; // end namespace detail - -// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// D = per-col alpha * acc + per-col beta * C + per-column bias template< + int StagesC, class CtaTileShapeMNK, + class EpilogueTile, class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > -using Sm90ScaledLinCombPerRowBias = +using Sm90PerColLinCombPerColBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias > >; -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) template< + int StagesC, class CtaTileShapeMNK, + class EpilogueTile, template class ActivationFn, class ElementOutput, class ElementCompute, @@ -1008,16 +1216,532 @@ template< class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBiasEltAct = + Sm90EVT, + Sm90PerColLinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha = {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + +// We only apply the scaling factor if output is fp8 +template +struct ScaleOutOp { template using Op = cutlass::first; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; + +template +using amax = cutlass::maximum_absolute_value_reduction; // propogate nans + +}; // end namespace detail + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerRowBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerColBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > -using Sm90ScaledLinCombPerRowBiasEltAct = - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // activation(Z) - // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias - Sm90ScaledLinCombPerRowBias - >, - Sm90ScalarBroadcast // scale_d - >; +using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + template < int StagesC, @@ -1025,35 +1749,49 @@ template < int FragmentSize, bool ReuseSmemC, bool DelayTmaStore, + class GmemLayoutTagAux, template class ActivationFn, class ElementOutput, class ElementCompute, + class ElementAux, + class ElementAmax, class ElementBias, class ElementSource, class ElementScalar, + int AlignmentAux, int AlignmentBias, FloatRoundStyle RoundStyle, class CtaTileShapeMNK, - class EpilogueTile + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S > struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >, CtaTileShapeMNK, - EpilogueTile -> : Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { using Impl = - Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; using Operation = - fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; struct Arguments { @@ -1071,6 +1809,9 @@ struct FusionCallbacks< ElementScalar const* scale_c_ptr = nullptr; ElementScalar const* scale_d_ptr = nullptr; + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; using StrideBeta = Stride<_0,_0,int64_t>; StrideAlpha dAlpha = {_0{}, _0{}, 0}; @@ -1083,34 +1824,113 @@ struct FusionCallbacks< using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + operator typename Impl::Arguments() const { - return - { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d - { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ = amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op - activation // unary args : activation - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } } }; @@ -1120,7 +1940,7 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias // if D is fp8 // amax_d = max(abs(elements in activation(Z))) // D = scale_d * activation(Z) @@ -1152,10 +1972,10 @@ template< int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > -using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8 = Sm90SplitTreeVisitor< - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias + Sm90ScaledLinCombPerColBias, // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d Sm90EVT, // amax_d @@ -1197,14 +2017,14 @@ template< int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > -using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8 = // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d Sm90EVT, // amax_d Sm90EVT, // activation(Z) Sm90EVT, // Aux = Z // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias + Sm90ScaledLinCombPerColBias > > >, @@ -1231,12 +2051,12 @@ template< int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > -using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = conditional_t, - Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< +using Sm90ScaledLinCombPerColBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8< CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle >, - Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< + Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8< CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > @@ -1268,7 +2088,7 @@ template < > struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + fusion::ScaledLinCombPerColBiasEltActAmaxAux< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >, @@ -1276,20 +2096,20 @@ struct FusionCallbacks< EpilogueTile, SmemLayoutAtom, CopyOpR2S -> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< +> : Sm90ScaledLinCombPerColBiasEltActAmaxAux< CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { using Impl = - Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + Sm90ScaledLinCombPerColBiasEltActAmaxAux< CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; using Operation = - fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + fusion::ScaledLinCombPerColBiasEltActAmaxAux< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; @@ -1317,7 +2137,7 @@ struct FusionCallbacks< StrideAlpha dAlpha = {_0{}, _0{}, 0}; StrideBeta dBeta = {_0{}, _0{}, 0}; - using StrideBias = Stride<_1,_0,int64_t>; + using StrideBias = Stride<_0,_1,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -1354,7 +2174,7 @@ struct FusionCallbacks< { // ternary op : (scale_a * scale_b * alpha) * acc + bias {{alpha, scale_a, scale_b}, {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias @@ -1405,14 +2225,14 @@ struct FusionCallbacks< { // unary op : store(Z) { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias {{alpha, scale_a, scale_b}, {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}} + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias @@ -1679,6 +2499,87 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = per-column alpha * per-row alpha * acc + beta * c +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentScalar = 128 / sizeof_bits_v, // Alignment of per-column and per-row scaling vectors + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90OuterProdLinComb = + Sm90EVT, // c(beta) * c(C) + c(alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // c(alpha) * c(acc) + Sm90OuterProduct<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, Stride<_0,_1,int>, AlignmentScalar>, // alpha_col * alpha_row + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + OuterProdLinComb, + CtaTileShapeMNK, + EpilogueTile +> : Sm90OuterProdLinComb { + using Impl = Sm90OuterProdLinComb; + using Operation = OuterProdLinComb; + + struct Arguments { + + // Give a name and flat ordering to the fusion callback args + using StrideCol = Stride<_1,_0,int>; + using StrideRow = Stride<_0,_1,int>; + using StrideBeta = Stride<_0,_0,int>; + ElementScalar const* alpha_ptr_col = nullptr; + ElementScalar const* alpha_ptr_row = nullptr; + ElementScalar beta = static_cast(0); + ElementScalar const* beta_ptr = nullptr; + StrideCol dAlphaCol = {}; + StrideRow dAlphaRow = {}; + StrideBeta dBeta = {}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { + {beta, beta_ptr, dBeta}, // leaf args : beta + {}, // leaf args : C + { + { alpha_ptr_col, alpha_ptr_row, dAlphaCol, dAlphaRow }, // leaf args : alpha cols / rows + {}, // leaf args : acc + {} + }, + {} + }; + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = softmax(top_k(alpha * acc + beta * C)) template< int TopK, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 131d0ba5b9..321daa6bcc 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -266,8 +266,8 @@ struct Sm90TreeVisitor< auto const& scale_op = get<0>(Impl::ops); auto const& added_op = get<2>(Impl::ops); if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { - return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || - is_C_load_needed() || + return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || + is_C_load_needed() || added_op.is_producer_load_needed(); } else { @@ -408,8 +408,9 @@ template < > struct Sm90TreeVisitor< Sm90Compute, cutlass::epilogue::thread::ReLu> || - cute::is_same_v, cutlass::epilogue::thread::Clamp> >>, + cute::enable_if_t, cutlass::epilogue::thread::ReLu> || + cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU> >>, Sm90TreeVisitor< Sm90AuxStore< Stages, @@ -503,7 +504,8 @@ struct Sm90TreeVisitor< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { ElementCompute pre_relu = frg_compute[i]; - if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp>) { + if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU>) { frg_compute[i] = relu(frg_compute[i], params_compute); } else { diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index a22bed4e0d..66b1086efc 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -734,11 +734,12 @@ struct Sm90ScalarBroadcast { // Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors template< class Element, - class StrideMNL = Stride<_0,_0,_0>, + class StrideMNL_ = Stride<_0,_0,_0>, int BroadcastCount = 1, template class ReductionFn = multiplies > struct Sm90ScalarBroadcastPtrArray { + using StrideMNL = StrideMNL_; static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); @@ -780,8 +781,8 @@ struct Sm90ScalarBroadcastPtrArray { CUTLASS_DEVICE bool is_producer_load_needed() const { - // producer load is needed if Element is not void and we have multiple scalars - return !cute::is_void_v and size<2>(params_ptr->dScalar[0]) != 0; + // producer load is needed if Element is not void + return !cute::is_void_v; } CUTLASS_DEVICE bool @@ -814,7 +815,7 @@ struct Sm90ScalarBroadcastPtrArray { CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar[0]) != 0) { + if (size<2>(params_ptr->dScalar[0]) != 0) { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -1377,6 +1378,171 @@ struct Sm90ColBroadcast { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Do outer product from the column and row loaded +// +template< + int Stages, + class CtaTileShapeMNK, + class ElementScalar, + class StrideColMNL_ = Stride<_1,_0,int64_t>, /// NOTE: Batched scaling untested for now + class StrideRowMNL_ = Stride<_0,_1,int64_t>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = false // Fallback scalar broadcast for nullptr params +> +struct Sm90OuterProduct { + using StrideColMNL = StrideColMNL_; + using StrideRowMNL = StrideRowMNL_; + static_assert(Stages == 0, "OuterProduct doesn't support smem usage"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(!EnableNullptr, "Nullptr fallback not implemented"); + static_assert(is_static_v(StrideColMNL{}))> && + is_static_v(StrideRowMNL{}))>, "Only batch stride can be dynamic"); + static_assert(take<0,2>(StrideColMNL{}) == Stride<_1,_0>{} && + take<0,2>(StrideRowMNL{}) == Stride<_0,_1>{}, "Row and column incorrectly formatted"); + + // Accumulator distributes col/row elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + ElementScalar const* ptr_col = nullptr; + ElementScalar const* ptr_row = nullptr; + StrideColMNL dCol = {}; + StrideRowMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90OuterProduct() { } + + CUTLASS_HOST_DEVICE + Sm90OuterProduct(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorCol, class RTensorCol, + class GTensorRow, class RTensorRow + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorCol&& tCgCol, RTensorCol&& tCrCol, + GTensorRow&& tCgRow, RTensorRow&& tCrRow, + Params const& params) + : tCgCol(cute::forward(tCgCol)) + , tCrCol(cute::forward(tCrCol)) + , tCgRow(cute::forward(tCgRow)) + , tCrRow(cute::forward(tCrRow)) + , params(params) {} + + GTensorCol tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensorCol tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensorRow tCgRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensorRow tCrRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + + CUTLASS_DEVICE void + begin() { + + // Filter so we don't issue redundant copies over stride-0 modes + copy(filter(tCgCol), filter(tCrCol)); + copy(filter(tCgRow), filter(tCrRow)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_colrow; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_colrow[i] = static_cast(tCrCol_mn(epi_v * FragmentSize + i) * tCrRow_mn(epi_v * FragmentSize + i)); + } + return frg_colrow; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mRow, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks< + decltype(tCgCol), decltype(tCrCol), + decltype(tCgRow), decltype(tCrRow) + >( + cute::move(tCgCol), cute::move(tCrCol), + cute::move(tCgRow), cute::move(tCrRow), + params + ); + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // Batch matrix broadcast diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index f9ebe7393e..9c87be0809 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -293,11 +293,11 @@ template < class LayoutOrStrideMNL, class SmemLayoutAtom, // Unused class CopyOpR2S, // Unused - int Alignment, + int Alignment, bool EnableNullptr > struct Sm90AuxStore< - 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, + 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr > { using ElementAux = Element; @@ -343,7 +343,7 @@ struct Sm90AuxStore< CUTLASS_HOST_DEVICE Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; CUTLASS_DEVICE bool @@ -381,7 +381,7 @@ struct Sm90AuxStore< tC_cAux(cute::forward(tC_cAux)), problem_shape_mnl(problem_shape_mnl), params_ptr(params_ptr) {} - + GTensorR2G tC_gAux; RTensor tC_rAux; CTensorR2G tC_cAux; @@ -414,7 +414,7 @@ struct Sm90AuxStore< Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); - + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); @@ -451,7 +451,7 @@ struct Sm90AuxStore< // Predication support Tensor coordAux = make_identity_tensor(shape(mAux)); Tensor tC_cAux = sm90_partition_for_epilogue( - coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks( cute::move(tC_gAux), @@ -703,7 +703,6 @@ struct Sm90RowReduction { else if constexpr (FinalReduction) { auto problem_shape_mnkl = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -753,19 +752,18 @@ struct Sm90RowReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { -#if !defined(CUTLASS_SKIP_REDUCTION_INIT) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); if (args.ptr_row != nullptr) { return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); } return Status::kSuccess; } - else -#endif - if constexpr (FinalReduction) { + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1023,9 +1021,7 @@ struct Sm90RowReduction { } else { if (is_reduced_lane) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + copy_aligned(tCrRow, recast(tCgBuf)); } } sync_fn(); @@ -1054,9 +1050,7 @@ struct Sm90RowReduction { } else { if (is_reduced_lane) { - // Filter so we don't issue redunant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), filter(tCsBuf)); + copy_aligned(tCrRow, tCsBuf); } } sync_fn(); @@ -1296,7 +1290,6 @@ struct Sm90ColReduction { else if constexpr (FinalReduction) { auto problem_shape_mnkl = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1348,19 +1341,18 @@ struct Sm90ColReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { -#if !defined(CUTLASS_SKIP_REDUCTION_INIT) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); if (args.ptr_col != nullptr) { return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); } return Status::kSuccess; } - else -#endif - if constexpr (FinalReduction) { + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1522,9 +1514,7 @@ struct Sm90ColReduction { using ElementGmem = cute::conditional_t; Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); if (is_reduced_lane) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrCol), recast(filter(tCgBuf))); + copy_aligned(tCrCol, recast(tCgBuf)); } sync_fn(); } @@ -1542,9 +1532,7 @@ struct Sm90ColReduction { // Dump warp reduction to smem workspace Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); if (is_reduced_lane) { - // Filter so we don't issue redunant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrCol), filter(tCsBuf)); + copy_aligned(tCrCol, tCsBuf); } sync_fn(); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 4f7d99fa32..48f4756d1f 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -300,7 +300,6 @@ struct Sm90VisitorImplBase { tuple ops; }; - template struct Sm90VisitorImpl : Sm90VisitorImplBase { @@ -658,7 +657,6 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl(std::move(callbacks_tuple)); } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// template< diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 9f1cd77434..186e996602 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -258,6 +258,54 @@ struct LeakyReLU > { } }; +// Y = min((X <= threshold ? 0 : X), upper_bound) +template +struct ThresholdReLU { + static constexpr bool kIsHeavy = false; + + struct Arguments { + T threshold = T(0); + T upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); + }; + + CUTLASS_HOST_DEVICE + T operator()(T value, T threshold, T upper_bound) const { + minimum_with_nan_propagation mn; + + return mn((value <= threshold ? T(0) : value), upper_bound); + } + + CUTLASS_HOST_DEVICE + T operator()(T value, Arguments const& args = Arguments()) const { + return operator()(value, args.threshold, args.upper_bound); + } +}; + +template +struct ThresholdReLU> { + static constexpr bool kIsHeavy = false; + + using Arguments = typename ThresholdReLU::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, T threshold, T upper_bound) const { + ThresholdReLU relu; + + Array retvals; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + retvals[i] = relu(values[i], threshold, upper_bound); + } + + return retvals; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, Arguments const& args = Arguments()) const { + return operator()(values, args.threshold, args.upper_bound); + } +}; + // Tanh operator template struct Tanh { @@ -311,26 +359,7 @@ struct Sigmoid { }; template -struct Sigmoid > { - static const bool kIsHeavy = true; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &value) const { - Array y; - Sigmoid sigmoid_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - y[i] = sigmoid_op(value[i]); - } - - return y; - } -}; - -template -struct Sigmoid> { - using T = half_t; +struct Sigmoid> { static const bool kIsHeavy = true; CUTLASS_HOST_DEVICE @@ -450,6 +479,9 @@ struct HardSwish > { } }; +template +using ScaledHardSwish = Scale>; + // // GELU function definitions implemented as described by // Hendrycks, D., and Gimpel, K. in diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index f74a36af4b..c3aa3ff4fb 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -169,7 +169,7 @@ class LinearCombination { /// Constructs the function object, possibly loading from pointers in host memory CUTLASS_HOST_DEVICE - LinearCombination(Params const ¶ms, int group_idx = 0) { + explicit LinearCombination(Params const ¶ms, int group_idx) { if (params.alpha_ptr_array != nullptr && params.alpha_ptr_array[group_idx] != nullptr) { alpha_ = *(params.alpha_ptr_array[group_idx]); } @@ -190,6 +190,10 @@ class LinearCombination { } } + CUTLASS_HOST_DEVICE + explicit LinearCombination(const Params & params) + : LinearCombination(params, /* group_idx */ 0) { } + /// Returns true if source is needed CUTLASS_HOST_DEVICE bool is_source_needed() const { diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 48b66a1446..4a0c67ba14 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -39,11 +39,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -478,6 +474,12 @@ class Epilogue : // Iterate over accumulator tile // + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wcuda-compat" + // Turn off clangs warning about loop unroll argument using parens. + #endif + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { @@ -531,6 +533,10 @@ class Epilogue : destination_iterator.store(output_fragment); ++destination_iterator; } + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif } }; diff --git a/include/cutlass/epilogue/threadblock/epilogue_base.h b/include/cutlass/epilogue/threadblock/epilogue_base.h index 6853f5f042..30432e80eb 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -43,11 +43,7 @@ #include #endif -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/matrix_shape.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h index 43b14c356c..486c03040a 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h @@ -38,11 +38,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h b/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h index 2be1fa55a1..85ddae7cbd 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h +++ b/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h @@ -38,11 +38,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h index 9efbee4770..aff0548543 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h @@ -39,11 +39,11 @@ #pragma once -#if defined(__CUDACC_RTC__) #include + +#if defined(__CUDACC_RTC__) #include #else -#include #include #endif diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h b/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h index 9bae7a742a..df5bbc5c0e 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h @@ -50,11 +50,11 @@ #pragma once -#if defined(__CUDACC_RTC__) #include + +#if defined(__CUDACC_RTC__) #include #else -#include #include #endif diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 7e6d2a698b..d69f43c4a5 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -39,11 +39,11 @@ #pragma once -#if defined(__CUDACC_RTC__) #include + +#if defined(__CUDACC_RTC__) #include #else -#include #include #endif diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h index 1d4c7016b9..7f82bac7e8 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -39,11 +39,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/array.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h b/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h index 259f0729ca..027830c299 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h @@ -303,6 +303,12 @@ class EpilogueWithVisitorCallbacks : // Pipeline Loop // + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wcuda-compat" + // Turn off clang warning about loop unroll argument using parens. + #endif + #pragma unroll(IterationsUnroll ? kIterations : 1) for (int iter_idx = 1; iter_idx < kIterations + 1; ++iter_idx) { @@ -377,8 +383,19 @@ class EpilogueWithVisitorCallbacks : callbacks.end_step(iter_idx-1); } + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif + } else { + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wcuda-compat" + // Turn off clang warning about loop unroll argument using parens. + #endif + #pragma unroll(IterationsUnroll ? kIterations : 1) for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { @@ -459,6 +476,11 @@ class EpilogueWithVisitorCallbacks : callbacks.end_step(iter_idx); } + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif + } callbacks.end_epilogue(); diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp index 7a332f11fb..28d482b704 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp @@ -335,7 +335,8 @@ struct VisitorAuxLoad{ template< class ThreadMap, class Element, - class StrideMNL + class StrideMNL, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct VisitorRowBroadcast { @@ -399,6 +400,16 @@ struct VisitorRowBroadcast { CUTLASS_DEVICE void begin_epilogue() { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_row == nullptr) { + auto tC_rRow_vec = recast>(coalesce(tC_rRow)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tC_rRow_vec); ++i) { + tC_rRow_vec[i].fill(params_ptr->null_default); + } + return; + } + } clear(tC_rRow); auto src_v = filter(tC_gRow); auto coord_v = filter(tC_cRow); @@ -406,7 +417,7 @@ struct VisitorRowBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src_v); ++i) { bool guard = get<1>(coord_v(i)) < n; - cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + cutlass::arch::global_load(dst_v(i), (void const *)&src_v(i), guard); } } @@ -464,7 +475,8 @@ struct VisitorRowBroadcast { template< class ThreadMap, class Element, - class StrideMNL = Stride<_1,_0,_0> + class StrideMNL = Stride<_1,_0,_0>, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct VisitorColBroadcast { @@ -523,6 +535,12 @@ struct VisitorColBroadcast { CUTLASS_DEVICE void begin_epilogue() { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_col == nullptr) { + fill(tC_rCol, params_ptr->null_default); + return; + } + } clear(tC_rCol); Tensor pred = make_tensor(shape(tC_gCol)); CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp index 1c24e22d5f..dcec7ac83c 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp @@ -519,10 +519,7 @@ struct VisitorRowReduction { // Guard against uses of the existing SMEM tile __syncthreads(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rSrc); ++i) { - copy_vec(filter(tRS_rSrc), filter(tRS_sRows)); - } + copy(tRS_rSrc, tRS_sRows); __syncthreads(); diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index 617b8e39fe..8a88c0abc3 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -391,7 +391,7 @@ struct OutputTileOptimalThreadMap { 1>; /// Initial offset function - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static MatrixCoord initial_offset(int thread_idx) { // int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); @@ -462,7 +462,7 @@ struct OutputTileOptimalThreadMap { static int const kThreads = Threads; /// Function to compute each thread's initial offset - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static MatrixCoord initial_offset(int thread_idx) { // int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index c512dd873b..3322a4c65c 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -212,15 +212,23 @@ class TileIteratorTensorOpMixed { // When the optimization is enabled, small tiles require separate logic. bool kN32_optimization = (WarpShape::kN * Detail::kLanesInQuad * Policy::kElementsPerAccess * sizeof_bits::value) % 1024 == 0; if (kN32_optimization) { + int ptr_idx = ((warp_column_ * sizeof_bits::value) / 1024) % Detail::kPointerCount; + if (ptr_idx == 0) { ptr = pointers_[0]; } else if (ptr_idx == 1) { - ptr = pointers_[1]; + if constexpr (AccessType::kElements >= 2) { + ptr = pointers_[1]; + } } else if (ptr_idx == 2) { - ptr = pointers_[2]; + if constexpr (AccessType::kElements >= 3) { + ptr = pointers_[2]; + } } else if (ptr_idx == 3) { - ptr = pointers_[3]; + if constexpr (AccessType::kElements >= 4) { + ptr = pointers_[3]; + } } } diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 38ea4008c2..cfb6b8bbb8 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -1053,8 +1053,8 @@ float_e5m2_t::float_e5m2_t(float_e4m3_t x) { /// datatype in runtime argument list. /// /// Currently supported runtime datatypes compatible with type_erased_dynamic_float8_t: -/// QMMAFormat::E5M2 -/// QMMAFormat::E4M3 +/// MXF8F6F4Format::E5M2 +/// MXF8F6F4Format::E4M3 /// /////////////////////////////////////////////////////////////// diff --git a/include/cutlass/floating_point_nvrtc.h b/include/cutlass/floating_point_nvrtc.h index fdbd80fcdd..c08396aa78 100644 --- a/include/cutlass/floating_point_nvrtc.h +++ b/include/cutlass/floating_point_nvrtc.h @@ -35,6 +35,12 @@ #pragma once +#include // CUTLASS_HOST_DEVICE +#include // uint32_t +#if !defined(__CUDACC_RTC__) +#include // std::memcpy +#endif + namespace cutlass { /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 5b2bc3c67f..3c4d5c76ba 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -50,7 +50,7 @@ #ifdef _MSC_VER // Provides support for alternate operators such as 'and', 'or', ... -#include +#include #endif // _MSC_VER namespace cutlass { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 64e27a8d8a..f58fde8803 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -35,6 +35,8 @@ #include "cutlass/pipeline/sm90_pipeline.hpp" #include "cutlass/gemm/collective/collective_mma_decl.hpp" #include "cutlass/gemm/collective/collective_builder_decl.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" // SM90 Collective Builders should be used only starting CUDA 12.0 #if (__CUDACC_VER_MAJOR__ >= 12) @@ -236,8 +238,9 @@ struct CollectiveBuilder< GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes; + + static constexpr int Sm90ReducedSmemCapacityBytes = + detail::sm90_smem_capacity_bytes; static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); @@ -343,7 +346,7 @@ public: return t; } else { - return cute::stride(t); + return cute::stride(t); } } @@ -415,15 +418,15 @@ public: static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; - static constexpr int PipelineStages = IsMixedInput ? - detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) : - detail::compute_stage_count_or_override(StageCountType{}); + static constexpr int PipelineStages = IsMixedInput ? + detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) + : detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaRmemAWarpSpecialized>; + MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput + , MainloopSm90TmaGmmaRmemAWarpSpecialized>; using SmemCopyAtomA = cute::conditional_t>; using SmemCopyAtomB = cute::conditional_t, void>; @@ -761,13 +764,13 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = cute::is_same_v ? 2 : 1; using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; - using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; - using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); @@ -867,13 +870,13 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = 1; using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; - using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; - using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); diff --git a/include/cutlass/gemm/collective/collective_builder_decl.hpp b/include/cutlass/gemm/collective/collective_builder_decl.hpp index c0570d37a9..c27a84f21e 100644 --- a/include/cutlass/gemm/collective/collective_builder_decl.hpp +++ b/include/cutlass/gemm/collective/collective_builder_decl.hpp @@ -54,6 +54,18 @@ struct StageCountAutoCarveout { explicit StageCountAutoCarveout(cute::Int) {} }; +namespace detail { + +// Forward Declaration +template +constexpr int +compute_carveout_from_epi(); + +} // namespace detail + +template +struct StageCountAutoCarveoutEpi : StageCountAutoCarveout()> {}; + using StageCountAuto = StageCountAutoCarveout<0>; // Used to automatically let the builder pick the kernel schedule. diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 103da9af7b..21f8a557be 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -41,9 +41,10 @@ #include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000..ed223a56b8 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1370 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule_, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ +public: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput; + using TileShape = TileShape_; + using KernelSchedule = KernelSchedule_; + +private: + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + // + // Type Aliases + // + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + using StrideScale = cute::Stride, int64_t, int64_t>; + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)) || + (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + ((cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + using SwappedSmemLayoutAtomA = cute::conditional_t; + using SwappedSmemLayoutAtomB = cute::conditional_t; + using SwappedSmemCopyAtomA = cute::conditional_t; + using SwappedSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using SwappedTransformA = cute::conditional_t; + using SwappedTransformB = cute::conditional_t; + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + /// Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalSwappedStrideA{})); + using SmemLayoutB = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalSwappedStrideB{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. + // NOTE: Deleting this assertion without required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage { + CUTE_ALIGNAS(SmemAlignmentA) cute::ArrayEngine> smem_A; + CUTE_ALIGNAS(SmemAlignmentB) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + struct TensorMapStorage { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_scale; + cute::TmaDescriptor smem_tensormap_zero; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementScale const** ptr_S = nullptr; + NonVoidStrideScale const* dS{}; + int chunk_size = 0; + ElementZero const** ptr_Z = nullptr; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using LayoutA = decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideA{}, int32_t(0)), InternalSwappedStrideA{})); + using LayoutB = decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideB{}, int32_t(0)), InternalSwappedStrideB{})); + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + void* tensormaps; + SwappedElementA const** ptr_A; + SwappedStrideA ptr_dA; + SwappedElementB const** ptr_B; + SwappedStrideB ptr_dB; + NonVoidElementScale const** ptr_S; + NonVoidStrideScale const* dS; + NonVoidElementZero const** ptr_Z; + int64_t scale_k; + int chunk_size; + int reload_factor = (chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + + if constexpr (SwapAB) { + init_M = get<1>(init_shape); + init_N = get<0>(init_shape); + } + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t mock_L = 1; + SwappedElementA const* ptr_A_first_batch; + SwappedElementB const* ptr_B_first_batch; + SwappedStrideA ptr_dA; + SwappedStrideB ptr_dB; + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + + if constexpr (not SwapAB) { + ptr_A_first_batch = reinterpret_cast(args.ptr_A); + ptr_B_first_batch = reinterpret_cast(args.ptr_B); + } + else { + ptr_A_first_batch = reinterpret_cast(args.ptr_B); + ptr_B_first_batch = reinterpret_cast(args.ptr_A); + } + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + if constexpr (not SwapAB) { + ptr_dA = args.dA; + ptr_dB = args.dB; + } + else { + ptr_dA = args.dB; + ptr_dB = args.dA; + } + dA = InternalSwappedStrideA{}; + if constexpr (is_layout::value) { + dA = make_layout( + transform_leaf(dA.shape(), [](auto x){ + if constexpr (not is_static_v) { + return static_cast(1); + } else { + return x; + } + }), + dA.stride()); + } + dB = InternalSwappedStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + if constexpr (not SwapAB) { + dA = args.dA; + dB = args.dB; + } + else { + dA = args.dB; + dB = args.dA; + } + ptr_dA = SwappedStrideA{}; + ptr_dB = SwappedStrideB{}; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, detail::get_gmem_layout(make_shape(init_M,init_K,mock_L), dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, detail::get_gmem_layout(make_shape(init_N,init_K,mock_L), dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + typename Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + void* tensormaps = workspace; + auto args_setup = [&](auto ptr_A, auto ptr_B, int64_t scale_k = 0, int chunk_size = 0, int reload_factor = 1) -> Params { + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tma_load_scale, + tma_load_zero, + tensormaps, + reinterpret_cast(ptr_A), + ptr_dA, + reinterpret_cast(ptr_B), + ptr_dB, + reinterpret_cast(args.ptr_S), + args.dS, + reinterpret_cast(args.ptr_Z), + scale_k, + chunk_size, + reload_factor, + dA, + dB + }; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return SwapAB ? args_setup(args.ptr_B, args.ptr_A) + : args_setup(args.ptr_A, args.ptr_B); + } + else if constexpr (ModeHasScales) { + // NOTE: fix chunk wise scaling + //auto scale_k = (K + args.chunk_size - 1) / args.chunk_size; + auto scale_k = 1; + ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); + StrideScale dS{}; + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS)); + tma_load_scale = make_tma_copy( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + return SwapAB ? args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = reinterpret_cast(args.ptr_Z); + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_layout(make_shape(init_M,scale_k,mock_L), dS)); + tma_load_zero = make_tma_copy( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + return SwapAB ? args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + + // Calculating workspace size + auto calculate_workspace_size = [SizeOfCuTensorMap, sm_count](uint32_t num_input_tensors) { + return num_input_tensors * SizeOfCuTensorMap * sm_count; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return calculate_workspace_size(2); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale tensormap copies + return calculate_workspace_size(3); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale and zeros tensormap copies + return calculate_workspace_size(4); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in get_workspace_size."); + } + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto get_stride = [](auto stride) { + if constexpr (cute::is_pointer_v>) { + return *stride; + } + else { + return stride; + } + }; + auto dA = get_stride(args.dA); + auto dB = get_stride(args.dB); + implementable = implementable && cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(M,K,L), dA)); + implementable = implementable && cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(N,K,L), dB)); + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? N : M; + const int scale_k = (K + args.chunk_size - 1) / args.chunk_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.chunk_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = Utils::compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = Utils::compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytesExtra = Utils::compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(M,K,mock_L), mainloop_params.dA))); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(N,K,mock_L), mainloop_params.dB))); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } + else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class... Ts, + class... TMs, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + static_assert(sizeof... (TMs) == 2, "Direct convert needs two tensormaps"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + static_assert(sizeof... (TMs) == 3, "Scaled convert needs three tensormaps"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + static_assert(sizeof... (TMs) == 4, "Scaled and zero convert needs four tensormaps"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = Utils::partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with chunk_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when chunk_size == K. + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_scale.with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_zero.with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SwappedSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SwappedSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = [&]{ + if constexpr (not is_layout::value) { + // Make register tensor with MMA layout + return make_fragment_like(tCrA_mma); + } + else { + // Make register tensor matching smem layout, converter will take care of de-swizzling + return make_tensor_like(tCsA(_,_,_,Int<0>{})); + } + }(); + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SwappedSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = Utils::partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = Utils::retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + constexpr int K_WAIT_MAX = cute::min(K_BLOCK_MAX - 1, 7); + static_assert(K_BLOCK_MAX >= 4, "Consider increasing TileShapeK"); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + if (K_BLOCK_MAX > 1) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + + // NOTE: Check this when applying swizzling PR on top of GGMD + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + + warpgroup_wait(); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + + // NOTE: Check this when applying swizzling PR on top of GGMD + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + else { + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_scale = &gmem_tensormap[sm_idx + 2*sm_count]; + cute::TmaDescriptor* tma_desc_zero = &gmem_tensormap[sm_idx + 3*sm_count]; + + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor pS_tensormap = make_tensor(mainloop_params.tma_load_scale.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sS_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_scale), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pS_tensormap), recast(sS_tensormap)); + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor pZ_tensormap = make_tensor(mainloop_params.tma_load_zero.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sZ_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_zero), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pZ_tensormap), recast(sZ_tensormap)); + } + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + + __syncwarp(); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale, tma_desc_zero); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_scale, + mainloop_params.ptr_S[next_batch]); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_zero, + mainloop_params.ptr_Z[next_batch]); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_address."); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_scale = {1,1,1,1,1}; + cute::array prob_stride_scale = {0,0,0,0,0}; + cute::array prob_shape_zero = {1,1,1,1,1}; + cute::array prob_stride_zero = {0,0,0,0,0}; + + SwappedElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, detail::get_gmem_layout(make_shape(M,K,Int<1>{}), mainloop_params.ptr_dA[next_group])); + + SwappedElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, detail::get_gmem_layout(make_shape(N,K,Int<1>{}), mainloop_params.ptr_dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + NonVoidElementScale const* ptr_S = nullptr; + // NOTE: figure out chunk wise scaling. auto scale_k = (K + mainloop_params.chunk_size - 1) / mainloop_params.chunk_size; + auto scale_k = 1; + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, + prob_shape_scale, prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = nullptr; + // NOTE: figure out chunk wise scaling. auto scale_k = (K + mainloop_params.chunk_size - 1) / mainloop_params.chunk_size; + auto scale_k = 1; + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, + prob_shape_zero, prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_scale) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_zero) { + stride = (stride * sizeof_bits_v) / 8; + } + + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_scale, + prob_shape_scale, + prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_zero, + prob_shape_zero, + prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_cp_fence_release."); + } + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_fence_acquire."); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 628750fc3a..5264aa4c7f 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -374,7 +374,7 @@ struct CollectiveMma< // Prepare the TMA loads for A and B // - constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; Tensor gA_mkl = get<0>(load_inputs); diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 73564d3c65..5c6c2a0f08 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -85,13 +85,40 @@ class GemmUniversalAdapter; ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// +namespace detail { + +// Work-around for some DispatchPolicy types not having a Stages member. +// In that case, the Stages value is 0. Most code should static_assert +// that the number of stages is valid. + +// Whether DispatchPolicy::Stages is valid. +// It should also be convertible to int, but if not, that will show up +// as a build error when GemmUniversalAdapter attempts to assign it to kStages. +template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } + else { + return 0; + } +} + +} // namespace detail + template class GemmUniversalAdapter< GemmKernel_, - cute::enable_if_t::value>> + cute::enable_if_t>::value>> { public: - using GemmKernel = GemmKernel_; + using GemmKernel = GetUnderlyingKernel_t; using TileShape = typename GemmKernel::TileShape; using ElementA = typename GemmKernel::ElementA; using ElementB = typename GemmKernel::ElementB; @@ -158,7 +185,7 @@ class GemmUniversalAdapter< CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; - static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); // Inspect TiledCopy for A and B to compute the alignment size static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< @@ -336,7 +363,7 @@ class GemmUniversalAdapter< } /// Primary run() entry point API that is static allowing users to create and manage their own params. - /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() + /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() static Status run(Params& params, cudaStream_t stream = nullptr, @@ -358,10 +385,10 @@ class GemmUniversalAdapter< [[maybe_unused]] constexpr bool is_static_1x1x1 = cute::is_static_v and cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; - dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - void* kernel_params[] = {¶ms}; + [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + [[maybe_unused]] void* kernel_params[] = {¶ms}; if constexpr (kEnableCudaHostAdapter) { // @@ -377,13 +404,23 @@ class GemmUniversalAdapter< #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); #endif - launch_result = cuda_adapter->launch(grid, - cluster, - block, - smem_size, - stream, - kernel_params, - 0); + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } } else { CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); @@ -392,8 +429,10 @@ class GemmUniversalAdapter< } else { CUTLASS_ASSERT(cuda_adapter == nullptr); - void const* kernel = (void const*) device_kernel; - if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { + [[maybe_unused]] void const* kernel = (void const*) device_kernel; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90 + ; + if constexpr (kClusterLaunch) { if constexpr (is_static_1x1x1) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); @@ -526,11 +565,11 @@ class GemmUniversalAdapter< template class GemmUniversalAdapter< GemmKernel_, - cute::enable_if_t::value>> + cute::enable_if_t>::value>> { public: - using GemmKernel = GemmKernel_; + using GemmKernel = GetUnderlyingKernel_t; static bool const kInternalTranspose = !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index fa275bdae1..236a1227c2 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -105,7 +105,8 @@ struct KernelCpAsyncWarpSpecializedPingpong { }; struct KernelCpAsyncWarpSpecializedCooperative { }; struct KernelTma { }; struct KernelTmaWarpSpecialized { }; -struct KernelTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedPingpong { +}; struct KernelTmaWarpSpecializedCooperative { }; @@ -247,6 +248,7 @@ struct MainloopSm90TmaGmmaRmemAWarpSpecialized { "KernelSchedule must be one of the warp specialized policies"); }; + template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, @@ -310,6 +312,7 @@ struct MainloopSm90TmaGmmaWarpSpecializedSparse { using Schedule = KernelSchedule; }; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp index 4a90a1d06d..fbc0fdd715 100644 --- a/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -69,7 +69,7 @@ struct GroupProblemShape { CUTLASS_HOST_DEVICE UnderlyingProblemShape const get_host_problem_shape(int32_t group_idx) const { - return host_problem_shapes[group_idx]; + return host_problem_shapes != nullptr ? host_problem_shapes[group_idx] : UnderlyingProblemShape{}; } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h b/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h new file mode 100644 index 0000000000..3b7b126ae4 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, + ElementAccumulator>::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void + > +struct DefaultGemmGroupedPerGroupScale; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout +> +struct DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + GroupScheduleMode_, + Operator, + SharedMemoryClear, + PermuteDLayout, + typename platform::enable_if< ! cutlass::is_complex::value>::type +> { + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments< + ElementA, + LayoutA, + ComplexTransform::kNone, + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, + kAlignmentB, + LayoutC, + kInternalTranspose + >; + + // Define the default GEMM kernel + using DefaultGemmKernel = typename kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementC, + typename MapArguments::LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + true, + Operator, + SharedMemoryClear, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + PermuteDLayout + >::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmGroupedPerGroupScale< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Complex-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear + > +struct DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + GroupScheduleMode_, + Operator, + SharedMemoryClear, + layout::NoPermute, /*PermuteDLayout*/ + typename platform::enable_if::value>::type +> { + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + LayoutC, + kInternalTranspose + >; + + using DefaultGemmKernel = typename kernel::DefaultGemmComplex< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + ElementC, + typename MapArguments::LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MapArguments::kTransformA, + MapArguments::kTransformB, + Operator, + false + >::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmGroupedPerGroupScale< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/ell_gemm.h b/include/cutlass/gemm/kernel/ell_gemm.h index 7cd6198021..aad3295925 100644 --- a/include/cutlass/gemm/kernel/ell_gemm.h +++ b/include/cutlass/gemm/kernel/ell_gemm.h @@ -691,7 +691,7 @@ struct EllGemm { static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); + constexpr bool is_double = (sizeof(typename Mma::IteratorA::Element) == 8); constexpr bool is_multiple_alignment = (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); const bool is_specialized_blocksize = @@ -699,11 +699,11 @@ struct EllGemm { && params.ell_blocksize >= Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { - mma.operator()( + mma.template operator()( gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); } else { - mma.operator()( + mma.template operator()( gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); } } diff --git a/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h b/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h new file mode 100644 index 0000000000..972681ab38 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Problem visitor for grouped GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform + bool Transposed = false +> +struct GemmGroupedPerGroupScale : + public GemmGrouped { + + // Inherit constructors + using Base = GemmGrouped; + + // Inherit type definitions + using typename Base::Mma; + using typename Base::Epilogue; + using typename Base::EpilogueOutputOp; + using typename Base::ThreadblockSwizzle; + using typename Base::Params; + using typename Base::SharedStorage; + + // Explicitly inherit the kTransposed constant + static bool const kTransposed = Base::kTransposed; + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. + // + typename Base::ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{ + 0, + threadblock_offset.n() + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k(), problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + ElementC *ptr_C = params.ptr_C[problem_idx]; + ElementC *ptr_D = params.ptr_D[problem_idx]; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset.mn() + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size.mn(), + thread_idx, + threadblock_offset.mn() + ); + + Epilogue epilogue( + shared_storage.kernel.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // The if branch is for the per-group scaling epilogue. The customized epilogue operator scales each gemm output by a scalar value. + // This branch is only enabled if EpilogueOutputOp is LinearCombination. + if constexpr (platform::is_same>::value) + { + EpilogueOutputOp output_op(params.output_op, problem_idx); + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + } else { + EpilogueOutputOp output_op(params.output_op); + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h index cdb8259930..5d8ce78908 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h @@ -437,7 +437,7 @@ class GemmWithEpilogueVisitorStreamk { int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; int m_end = params.block_mapping.problem_size.m(); - return Mma::IteratorA( + return typename Mma::IteratorA( params.params_A, ptr_A, { m_end, tile_work.k_end }, @@ -466,7 +466,7 @@ class GemmWithEpilogueVisitorStreamk { int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; int n_end = params.block_mapping.problem_size.n(); - return Mma::IteratorB( + return typename Mma::IteratorB( params.params_B, ptr_B, { tile_work.k_end, n_end }, diff --git a/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/include/cutlass/gemm/kernel/grouped_problem_visitor.h index 31787372a3..4df76ec0bd 100644 --- a/include/cutlass/gemm/kernel/grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/grouped_problem_visitor.h @@ -66,10 +66,10 @@ struct BaseGroupedProblemVisitor { int32_t problem_idx; int32_t problem_start; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE ProblemInfo(int32_t problem_idx_, int32_t problem_start_) : problem_idx(problem_idx_), problem_start(problem_start_) {} }; diff --git a/include/cutlass/gemm/kernel/params_universal_base.h b/include/cutlass/gemm/kernel/params_universal_base.h index 86986f2e21..172855edf4 100644 --- a/include/cutlass/gemm/kernel/params_universal_base.h +++ b/include/cutlass/gemm/kernel/params_universal_base.h @@ -182,7 +182,7 @@ struct UniversalParamsBase CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes"); cudaError_t result = cudaMemsetAsync( - semaphore, + static_cast(workspace), 0, workspace_bytes, stream); diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index 6b36db21a4..e8383faf19 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -479,14 +479,14 @@ struct Rank2KGrouped { // Construct iterators to A and B operands for Mma1 typename Mma1::IteratorA iterator_A( - Mma1::IteratorA::Params(ldm_A), + typename Mma1::IteratorA::Params(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_MxK); typename Mma1::IteratorB iterator_BT( - Mma1::IteratorB::Params(ldm_B), + typename Mma1::IteratorB::Params(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, @@ -494,14 +494,14 @@ struct Rank2KGrouped { // Construct iterators to A and B operands for Mma2 typename Mma2::IteratorA iterator_B( - Mma2::IteratorA::Params(ldm_B), + typename Mma2::IteratorA::Params(ldm_B), ptr_B, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_MxK); typename Mma2::IteratorB iterator_AT( - Mma2::IteratorB::Params(ldm_A), + typename Mma2::IteratorB::Params(ldm_A), ptr_A, {problem_size_k, problem_size.n()}, thread_idx, @@ -560,7 +560,7 @@ struct Rank2KGrouped { // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), ptr_C, problem_size.mn(), thread_idx, @@ -570,7 +570,7 @@ struct Rank2KGrouped { // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), ptr_D, problem_size.mn(), thread_idx, @@ -634,7 +634,7 @@ struct Rank2KGrouped { // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), ptr_C, problem_size.mn(), thread_idx, @@ -644,7 +644,7 @@ struct Rank2KGrouped { // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), ptr_D, problem_size.mn(), thread_idx, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 823e919ed1..c0c10b97b7 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -218,11 +218,6 @@ class GemmUniversal< uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -231,6 +226,11 @@ class GemmUniversal< workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + TileSchedulerParams scheduler; if constexpr (IsGroupedGemmKernel) { scheduler = TileScheduler::to_underlying_arguments( @@ -276,10 +276,6 @@ class GemmUniversal< size_t workspace_size = 0; constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { @@ -294,6 +290,10 @@ class GemmUniversal< workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; } @@ -306,23 +306,25 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } @@ -633,7 +635,7 @@ class GemmUniversal< constexpr bool IsEpiLoad = true; if (work_tile_info.is_valid()) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -644,7 +646,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } load_order_barrier.wait(); @@ -667,7 +669,7 @@ class GemmUniversal< auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; @@ -697,7 +699,7 @@ class GemmUniversal< // tensormap update { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -708,7 +710,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } } @@ -738,7 +740,7 @@ class GemmUniversal< if (work_tile_info.is_valid()) { if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -749,8 +751,8 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, - epi_store_tensormap, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, consumer_warp_group_idx); } } @@ -805,7 +807,7 @@ class GemmUniversal< params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { @@ -843,7 +845,7 @@ class GemmUniversal< problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); } if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -854,7 +856,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 386337641d..1b7c0cb412 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -226,11 +226,6 @@ class GemmUniversal< uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -239,6 +234,11 @@ class GemmUniversal< workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means // subtile will not be used, therefore separate reduction will not be enabled. @@ -288,10 +288,6 @@ class GemmUniversal< size_t workspace_size = 0; constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { @@ -306,6 +302,10 @@ class GemmUniversal< workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; } @@ -318,27 +318,28 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { return status; } + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } return status; } @@ -666,7 +667,7 @@ class GemmUniversal< constexpr bool IsEpiLoad = true; if (work_tile_info.is_valid()) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -677,7 +678,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } load_order_barrier.wait(); @@ -700,7 +701,7 @@ class GemmUniversal< auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; @@ -730,7 +731,7 @@ class GemmUniversal< // tensormap update { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -741,7 +742,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } } @@ -771,7 +772,7 @@ class GemmUniversal< if (work_tile_info.is_valid()) { if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -782,7 +783,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); } @@ -844,7 +845,7 @@ class GemmUniversal< params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { @@ -897,7 +898,7 @@ class GemmUniversal< problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); } if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -908,7 +909,7 @@ class GemmUniversal< // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 243a9e7083..0dece13924 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -51,8 +51,6 @@ namespace cutlass::gemm::kernel { -/////////////////////////////////////////////////////////////////////////////// - template < class ProblemShape_, class CollectiveMainloop_, @@ -107,7 +105,6 @@ class GemmUniversal< TileShape, ClusterShape >::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; @@ -122,7 +119,8 @@ class GemmUniversal< static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - + static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups; + /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 40; static constexpr uint32_t MmaRegisterRequirement = 232; @@ -207,22 +205,23 @@ class GemmUniversal< uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* mainloop_workspace = nullptr; // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means // subtile will not be used, therefore separate reduction will not be enabled. constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles + ); return { args.mode, @@ -254,13 +253,12 @@ class GemmUniversal< size_t workspace_size = 0; constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); return workspace_size; } @@ -273,17 +271,17 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; @@ -314,6 +312,7 @@ class GemmUniversal< operator()(Params const& params, char* smem_buf) { using namespace cute; using X = Underscore; + #if defined(__CUDA_ARCH_FEAT_SM90_ALL) # define ENABLE_SM90_KERNEL_LEVEL 1 #endif @@ -487,7 +486,6 @@ class GemmUniversal< // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); collective_mainloop.load( @@ -581,11 +579,10 @@ class GemmUniversal< auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - // Allocate the accumulators for the (M,N) blk_shape // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { collective_mainloop.mma( mainloop_pipeline, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index cf4a552cb1..c19a8e9f8c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -105,14 +105,24 @@ class GemmUniversal< static_assert(!cute::is_same_v, "Ping-pong kernel does not currently support stream-K scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + TileSchedulerTag, + ArchTag, + TileShape, + ClusterShape + >::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; - + // Warp specialization thread count per threadblock + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 4 warp + static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads * NumMmaWarpGroups + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static_assert(NumMMAThreads == 128, "Pingpong kernel must have TiledMMA operating using 128 threads."); + static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total."); /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 40; @@ -142,7 +152,7 @@ class GemmUniversal< alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; } pipelines; - + struct TensorStorage : cute::aligned_struct<128, _1> { using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -208,16 +218,17 @@ class GemmUniversal< uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* mainloop_workspace = nullptr; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); return { args.mode, @@ -225,7 +236,9 @@ class GemmUniversal< CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace) + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles + ) }; } @@ -247,13 +260,14 @@ class GemmUniversal< static size_t get_workspace_size(Arguments const& args) { size_t workspace_size = 0; - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; } @@ -266,17 +280,17 @@ class GemmUniversal< static constexpr uint32_t NumEpilogueSubTiles = 1; static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; @@ -308,9 +322,12 @@ class GemmUniversal< using namespace cute; using X = Underscore; +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); #else // Preconditions @@ -350,6 +367,7 @@ class GemmUniversal< CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); } + // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; @@ -450,8 +468,8 @@ class GemmUniversal< auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); TileScheduler scheduler{params.scheduler}; - if (warp_group_role == WarpGroupRole::Consumer1) { + // Advance 2nd Math WG to the next work tile for the startup scheduler.advance_to_next_work(); // Advance 2nd Math WG pipeline states to the end of 1st Math WG @@ -466,7 +484,7 @@ class GemmUniversal< if (warp_group_role == WarpGroupRole::Producer) { cutlass::arch::warpgroup_reg_dealloc(); - + // Mainloop Producer Warp if (producer_warp_role == ProducerWarpRole::Mainloop) { // Ensure that the prefetched kernel does not touch @@ -546,6 +564,7 @@ class GemmUniversal< // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End } // Producer Warp Group End @@ -564,7 +583,7 @@ class GemmUniversal< return; } #endif - + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index 5e61e7c99d..08437c70c5 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -29,8 +29,8 @@ * **************************************************************************************************/ #pragma once -#include "cutlass/gemm/kernel/static_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/static_tile_scheduler.hpp" namespace cutlass::gemm::kernel::detail { diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index 888be276d5..a30d9ce08b 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -337,12 +337,16 @@ class PersistentTileSchedulerSm90Group { uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx); divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + // With static schedulers, we launch grid such that all cluster are linear (1-D) order, i.e., + // there can only be one cluster in the minor dimension. get_grid_shape() in scheduler params + // put cluster_shape.m/n() as the minor dimension based on raster order AlongN/M resp. + // Therefore, the offset of a CTA (inside a cluster) in the minor dimension can be directly be + // inferred by the blockIdx along the minor dimension. if (raster_order == RasterOrder::AlongN) { - cluster_minor_offset = cta_m_in_cluster; + cluster_minor_offset = blockIdx.x; } else { - cluster_minor_offset = cta_n_in_cluster; + cluster_minor_offset = blockIdx.y; } uint64_t cluster_idx_minor, cluster_idx_major; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index 80b374ad7b..b5e62164da 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -58,7 +58,9 @@ class PersistentTileSchedulerSm90StreamK { using UnderlyingArguments = typename UnderlyingScheduler::Arguments; using UnderlyingParams = typename UnderlyingScheduler::Params; + dim3 block_id_in_cluster_; uint64_t current_work_linear_idx_ = 0; + uint32_t unit_iter_start_ = 0; public: @@ -240,25 +242,26 @@ class PersistentTileSchedulerSm90StreamK { CUTLASS_HOST_DEVICE PersistentTileSchedulerSm90StreamK() { }; - CUTLASS_HOST_DEVICE - PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_) { + CUTLASS_DEVICE + PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_), block_id_in_cluster_(cute::block_id_in_cluster()) { if (params_.raster_order_ == RasterOrder::AlongN) { current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); } else { current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); } + } CUTLASS_DEVICE WorkTileInfo - get_current_work() const { - return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params); + get_current_work() { + return get_current_work_for_linear_idx(unit_iter_start_, current_work_linear_idx_, block_id_in_cluster_, scheduler_params); } CUTLASS_DEVICE static WorkTileInfo - get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) { + get_current_work_for_linear_idx(uint32_t &unit_iter_start, uint64_t linear_idx, dim3 block_id_in_cluster, Params const& params) { // The maximum number of work units is units_per_problem_ * splits_. // The multiplication by splits_ is used for handling split-K, in which // units_per_problem_ is equal to the total number of output tiles. To account @@ -271,7 +274,7 @@ class PersistentTileSchedulerSm90StreamK { } WorkTileInfo work_tile_info; - assign_work(params, linear_idx, work_tile_info); + assign_work(params, linear_idx, block_id_in_cluster, work_tile_info, unit_iter_start); return work_tile_info; } @@ -283,13 +286,15 @@ class PersistentTileSchedulerSm90StreamK { bool continue_current_work(WorkTileInfo& work_tile_info) const { return continue_current_work_for_linear_idx( - current_work_linear_idx_, work_tile_info, scheduler_params); + current_work_linear_idx_, unit_iter_start_, block_id_in_cluster_, work_tile_info, scheduler_params); } CUTLASS_DEVICE static bool continue_current_work_for_linear_idx( uint64_t linear_idx, + uint32_t unit_iter_start, + dim3 block_id_in_cluster, WorkTileInfo& work_tile_info, Params const& params) { @@ -298,7 +303,7 @@ class PersistentTileSchedulerSm90StreamK { if (work_tile_info.k_tile_remaining == 0) { return false; } - assign_work(params, linear_idx, work_tile_info); + fast_assign_work(unit_iter_start, params, linear_idx, block_id_in_cluster, work_tile_info); return work_tile_info.is_valid(); } @@ -316,9 +321,11 @@ class PersistentTileSchedulerSm90StreamK { return false; } return not get_current_work_for_linear_idx( + unit_iter_start_, current_work_linear_idx_ + ( uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count) ), + block_id_in_cluster_, scheduler_params ).is_valid(); } @@ -420,22 +427,24 @@ class PersistentTileSchedulerSm90StreamK { uint64_t reduction_tile_idx = tile_idx; uint64_t num_peers = 0; uint64_t reduction_peer_offset = 0; - if (params.requires_separate_reduction()) { + if ( + params.requires_separate_reduction() + ) { // If separate reduction is to be performed, each stream-K unit writes its partials // to a separate portion of the workspace. There are as many of these portions as there // are peers for a given output tile, so we multiply the tile index by the maximum peer count. - auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, static_cast(work_tile_info.K_idx)); + auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, work_tile_info); + auto peer_id_in_output_tile = my_peer_id - first_peer_id; num_peers = last_peer_id - first_peer_id + 1; - reduction_tile_idx *= Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); - reduction_peer_offset = my_peer_id * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}); + reduction_tile_idx = tile_idx * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); + reduction_peer_offset = peer_id_in_output_tile * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * num_accumulator_mtxs; } // Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood. // Thus, the start of the reduction space is the same across all threads in a warp group. - uint64_t reduction_offset = - (static_cast(cute::size<0>(TileShape{})) * static_cast(cute::size<1>(TileShape{})) * reduction_tile_idx * num_accumulator_mtxs) + - reduction_peer_offset + + uint64_t reduction_offset_base = (static_cast(cute::size<0>(TileShape{})) * static_cast(cute::size<1>(TileShape{})) * reduction_tile_idx * num_accumulator_mtxs) + (static_cast(size(accumulators)) * barrier_idx * BarrierManager::ThreadCount); + uint64_t reduction_offset = reduction_offset_base + reduction_peer_offset; ElementAccumulator* group_reduction_workspace = reinterpret_cast(params.reduction_workspace_) + reduction_offset; @@ -457,7 +466,9 @@ class PersistentTileSchedulerSm90StreamK { if (params.divmod_splits_.divisor > 1) { reduction_tiles = params.units_per_problem_; } - else if (params.requires_separate_reduction()) { + else if ( + params.requires_separate_reduction() + ) { reduction_tiles = params.sk_tiles_ * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); } else { @@ -470,29 +481,17 @@ class PersistentTileSchedulerSm90StreamK { reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); if (work_tile_info.is_reduction_unit()) { - plus add_fragments; - uint64_t peer_offset = size(accumulators) * num_barriers * BarrierManager::ThreadCount; - // Wait until the peers collaborating on this output tile have all written // their accumulators to workspace. BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, num_peers); - // Load the first peer's data - BlockStripedReduceT::load(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx); - - for (uint64_t i = 1; i < num_peers; ++i) { - // Load peer fragment - AccumulatorArrayT addend_fragment; - auto peer_reduction_workspace = reinterpret_cast(group_reduction_workspace + (i * peer_offset)); - - BlockStripedReduceT::load(addend_fragment, peer_reduction_workspace, barrier_group_thread_idx); - - // Add peer fragment - *accumulator_array = add_fragments(*accumulator_array, addend_fragment); - } + separate_reduction(accumulators, num_barriers, group_reduction_workspace, barrier_group_thread_idx, num_peers, num_accumulator_mtxs); } else if (!compute_epilogue(work_tile_info, params)) { - if (params.requires_separate_reduction() || work_tile_info.K_idx == 0) { + if ( + params.requires_separate_reduction() + || work_tile_info.K_idx == 0 + ) { // The first peer initializes the workspace partials in the non-separate-reduction case, // and all peers write to their own location in workspace when using separate reduction BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); @@ -513,12 +512,16 @@ class PersistentTileSchedulerSm90StreamK { BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment); } else { - if (params.reduction_mode_ == ReductionMode::Deterministic) { + if ( + params.reduction_mode_ == ReductionMode::Deterministic + ) { + // Wait until the preceding split added its accumulators BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + } else { - // Wait unitl the first split has stored its accumulators + // Wait until the first split has stored its accumulators BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); } @@ -528,6 +531,36 @@ class PersistentTileSchedulerSm90StreamK { } } + template + CUTLASS_DEVICE + static void + separate_reduction( + FrgTensorC& accumulators, + uint32_t num_barriers, + typename FrgTensorC::value_type* reduction_workspace, + uint32_t thread_idx, + uint64_t num_peers, + uint32_t num_accumulator_mtxs) { + using AccumulatorArrayT = Array; + using BlockStripedReduceT = BlockStripedReduce; + + AccumulatorArrayT* accumulator_array = reinterpret_cast(accumulators.data()); + + plus add_fragments; + uint64_t peer_offset = cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * num_accumulator_mtxs; + + for (uint64_t i = 0; i < num_peers; ++i) { + // Load peer fragment + AccumulatorArrayT addend_fragment; + auto peer_reduction_workspace = reinterpret_cast(reduction_workspace + (i * peer_offset)); + + BlockStripedReduceT::load(addend_fragment, peer_reduction_workspace, thread_idx); + + // Add peer fragment + *accumulator_array = add_fragments(*accumulator_array, addend_fragment); + } + } + // Returns whether the block assigned this work should compute the epilogue for the corresponding // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. CUTLASS_HOST_DEVICE @@ -587,6 +620,7 @@ class PersistentTileSchedulerSm90StreamK { args.max_swizzle_size, args.raster_order, args.decomposition_mode, + args.reduction_mode, mma_warp_groups, sizeof_bits::value, sizeof_bits::value, @@ -627,6 +661,7 @@ class PersistentTileSchedulerSm90StreamK { args.max_swizzle_size, args.raster_order, args.decomposition_mode, + args.reduction_mode, mma_warp_groups, sizeof_bits::value, sizeof_bits::value, @@ -668,224 +703,235 @@ class PersistentTileSchedulerSm90StreamK { return get_current_work(); } -private: - // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info - // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining - // iterations) is used to find the next tile in the current work unit. + // Given raster order and current work tile linear index, reset cta m and n index in the cluster. CUTLASS_DEVICE - static void - assign_work( + static dim3 + get_current_work_cta_m_n_in_cluster( Params const& params, uint64_t linear_idx, - WorkTileInfo& work_tile_info) { - - auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = cute::block_id_in_cluster(); + dim3 block_id_in_cluster) { + auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = block_id_in_cluster; uint64_t cta_m_in_cluster = static_cast(cta_m_in_cluster_); uint64_t cta_n_in_cluster = static_cast(cta_n_in_cluster_); - uint64_t output_tile_id = linear_idx; - if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) { - // Separate-reduction work - auto cluster_size = params.get_cluster_size(); - // Divide up the linearized separate reduction units into clusters - uint64_t cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_)); - uint64_t cluster_tile_idx, epi_subtile_idx; - params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx); - // Bring the linearized tile ID back into the space of tiles, rather than clusters - output_tile_id = cluster_tile_idx * cluster_size; + return {static_cast(cta_m_in_cluster), static_cast(cta_n_in_cluster), _}; + } - work_tile_info.setup_separate_reduction(epi_subtile_idx); +private: + + CUTLASS_DEVICE + static uint32_t + get_current_work_iter_start_possible_update_work_tile_k_remaining( + Params const& params, + uint64_t linear_idx, + WorkTileInfo& work_tile_info) { + // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K + // threadblock individually. For the most part, the set of K iterations corresponding to stream-K + // work was divided amongst stream-K threadblocks, and a threadblock determined which tile + // it would compute a (potentially-partial) output tile for based on the space of k iterations + // assigned to it. This often results in stream-K threadblocks processing tiles with different + // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the + // (generally few) waves of threadblocks assigned to compute stream-K work. + // + // With the introduction of threadblock clusters, there is additional benefit to maintaining + // locality in the K dimension: shared portions of operands can be multicasted to threadblocks + // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to + // threadblocks respects the ability to perform multicasting. + // + // To do so, we divide up the linearized stream-K units into clusters and share the same K + // offsets for work within clusters. + uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); + + uint64_t group_idx; + params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx); + + // Determine whether we are in a "big group" that will process an additional + // stream-K cluster tile. + uint64_t sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_); + uint64_t sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles); + if (group_idx < params.big_groups_) { + ++sk_cluster_tiles_in_group; } - else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { - // Data-parallel work - output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; - work_tile_info.K_idx = 0; - work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; - work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; + + // Determine whether we are in a "big unit" within the group, that will process + // an additional K chunk in the group. + uint64_t sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size(); + uint64_t k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; + uint64_t k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); + uint64_t big_units_in_group = params.div_cluster_size( + k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor)); + + uint64_t split; + params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); + + bool is_split_k = params.divmod_splits_.divisor > 1; + uint64_t big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; + uint64_t big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; + uint64_t linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; + uint64_t k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; + + // Determine the starting k iteration computed by this stream-K work unit + uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + + (k_tiles_per_split * split); + + // Adjust the starting position and number of k iterations for "big units," which + // compute one extra iteration. If there are any big units, they will be the first + // in the linearized ID space. + auto k_tiles_in_my_split = k_tiles_per_split; + if (big_unit_cmp_lhs < big_unit_cmp_rhs) { + // Since the "big units" are the first units in the linearized ID space, each + // of the units preceding this big unit computed one extra iteration. Thus, + // we must offset our start iteration by the number of units that precede + // the current unit in the linearized ID space. + unit_iter_start += big_unit_cmp_lhs; + ++k_tiles_in_my_split; } else { - // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K - // threadblock individually. For the most part, the set of K iterations corresponding to stream-K - // work was divided amongst stream-K threadblocks, and a threadblock determined which tile - // it would compute a (potentially-partial) output tile for based on the space of k iterations - // assigned to it. This often results in stream-K threadblocks processing tiles with different - // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the - // (generally few) waves of threadblocks assigned to compute stream-K work. - // - // With the introduction of threadblock clusters, there is additional benefit to maintaining - // locality in the K dimension: shared portions of operands can be multicasted to threadblocks - // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to - // threadblocks respects the ability to perform multicasting. - // - // To do so, we divide up the linearized stream-K units into clusters and share the same K - // offsets for work within clusters. - - uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); - - uint64_t group_idx; - params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx); - - // Determine whether we are in a "big group" that will process an additional - // stream-K cluster tile. - uint64_t sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_); - uint64_t sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles); - if (group_idx < params.big_groups_) { - ++sk_cluster_tiles_in_group; + // Increment by one for each of the big clusters (since all big units precede this unit) + unit_iter_start += big_unit_cmp_rhs; + } + if (!is_split_k) { + // Adjust the unit starting position and number of tiles to avoid + // computing splits of size less than min_iters_per_sk_unit_ + int unused, start_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); + if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another + // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take over these K tiles. + unit_iter_start -= start_tile_k_tile; + k_tiles_in_my_split += start_tile_k_tile; } - - // Determine whether we are in a "big unit" within the group, that will process - // an additional K chunk in the group. - uint64_t sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size(); - uint64_t k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; - uint64_t k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); - uint64_t big_units_in_group = params.div_cluster_size( - k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor)); - - uint64_t split; - params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); - - bool is_split_k = params.divmod_splits_.divisor > 1; - uint64_t big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; - uint64_t big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; - uint64_t linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; - uint64_t k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; - - // Determine the starting k iteration computed by this stream-K work unit - uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + - (k_tiles_per_split * split); - - // Adjust the starting position and number of k iterations for "big units," which - // compute one extra iteration. If there are any big units, they will be the first - // in the linearized ID space. - auto k_tiles_in_my_split = k_tiles_per_split; - if (big_unit_cmp_lhs < big_unit_cmp_rhs) { - // Since the "big units" are the first units in the linearized ID space, each - // of the units preceding this big unit computed one extra iteration. Thus, - // we must offset our start iteration by the number of units that precede - // the current unit in the linearized ID space. - unit_iter_start += big_unit_cmp_lhs; - ++k_tiles_in_my_split; + else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. + auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); + unit_iter_start += adjustment_tiles; + k_tiles_in_my_split -= adjustment_tiles; } - else { - // Increment by one for each of the big clusters (since all big units precede this unit) - unit_iter_start += big_unit_cmp_rhs; + else if (params.ktile_start_alignment_count_ == 2 && start_tile_k_tile % 2 != 0) { + // ktile for each SM start from even number + // If start from odd number ktile within the output tile + // now start at the ktile one before my initial ktile start (take one ktile from prev sm) + // if end on odd number ktile within the output tile + // now end at ktile that one before my ktile end (give one ktile to next sm) + unit_iter_start -= 1; + k_tiles_in_my_split += 1; } + } + if (work_tile_info.k_tile_count == 0) { + // This is a new unit if (!is_split_k) { - // Adjust the unit starting position and number of tiles to avoid + // + // Adjust the unit ending position and number of tiles to avoid // computing splits of size less than min_iters_per_sk_unit_ - int unused, start_tile_k_tile; - params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); - if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { - // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another - // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to take over these K tiles. - unit_iter_start -= start_tile_k_tile; - k_tiles_in_my_split += start_tile_k_tile; - } - else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { - // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // + + // Begin by assuming that no adjustment is needed + auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; + + int unused, end_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); + + if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. - auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); - unit_iter_start += adjustment_tiles; - k_tiles_in_my_split -= adjustment_tiles; + k_tiles_in_my_split -= end_tile_k_tile; + } + else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take on these K tiles. + k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); } - else if (params.ktile_start_alignment_count == 2 && start_tile_k_tile % 2 != 0) { + else if (params.ktile_start_alignment_count_ == 2 && end_tile_k_tile % 2 != 0) { // ktile for each SM start from even number // If start from odd number ktile within the output tile // now start at the ktile one before my initial ktile start (take one ktile from prev sm) - // if end on odd number ktile within the output tile + // If end on odd number ktile within the output tile, // now end at ktile that one before my ktile end (give one ktile to next sm) - unit_iter_start -= 1; - k_tiles_in_my_split += 1; + k_tiles_in_my_split -= 1; } } - if (work_tile_info.k_tile_count == 0) { - // This is a new unit - - if (!is_split_k) { - // - // Adjust the unit ending position and number of tiles to avoid - // computing splits of size less than min_iters_per_sk_unit_ - // - - // Begin by assuming that no adjustment is needed - auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; - - int unused, end_tile_k_tile; - params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); - - if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { - // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. - k_tiles_in_my_split -= end_tile_k_tile; - } - else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { - // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to take on these K tiles. - k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); - } - else if (params.ktile_start_alignment_count == 2 && end_tile_k_tile % 2 != 0) { - // ktile for each SM start from even number - // If start from odd number ktile within the output tile - // now start at the ktile one before my initial ktile start (take one ktile from prev sm) - // If end on odd number ktile within the output tile, - // now end at ktile that one before my ktile end (give one ktile to next sm) - k_tiles_in_my_split -= 1; - } - } - - work_tile_info.k_tile_remaining = k_tiles_in_my_split; - } - - uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; - - // Find the output tile corresponding to the final k tile covered by this - // work unit. Stream-K work units will work backwards in terms of the tiles they - // are responsible computing. This is beneficial because the final (partial) - // tile computed by a stream-K block is typically the beginning of the output - // tile, while the beginning (partial) tile is typically the ending of another - // output tile. Since ending portions of an output tile must reduce across - // other work units computing portions of that output tile, it is preferable - // for them to be computed later, so as to reduce the likelihood of blocking - // on other work. - - auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); - uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; - uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; - - // Convert the output tile from the linearized space within each group to the - // overall linearized space. - output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx; - - // Bring the linearized tile ID back into the space of tiles, rather than clusters - output_tile_id *= params.get_cluster_size(); + work_tile_info.k_tile_remaining = k_tiles_in_my_split; + } + return unit_iter_start; + } - // The final linearized tile ID is in units of the cluster dimension over which we rasterize. - if (params.raster_order_ == RasterOrder::AlongN) { - output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } - else { - output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } + // Update output tile index given existing remaining k tiles of current work tile. + CUTLASS_DEVICE + static uint64_t update_output_tile_id_and_work_tile_k( + Params const& params, + WorkTileInfo& work_tile_info, + uint64_t linear_idx, + uint32_t unit_iter_start, + uint64_t cta_m_in_cluster, + uint64_t cta_n_in_cluster) { + // we divide up the linearized stream-K units into clusters and share the same K + // offsets for work within clusters. + uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); + + uint64_t unused, group_idx; + params.divmod_sk_groups_(unused, group_idx, cluster_linear_work_idx); + + uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; + + // Find the output tile corresponding to the final k tile covered by this + // work unit. Stream-K work units will work backwards in terms of the tiles they + // are responsible computing. This is beneficial because the final (partial) + // tile computed by a stream-K block is typically the beginning of the output + // tile, while the beginning (partial) tile is typically the ending of another + // output tile. Since ending portions of an output tile must reduce across + // other work units computing portions of that output tile, it is preferable + // for them to be computed later, so as to reduce the likelihood of blocking + // on other work. + + auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); + uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; + uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; + + // Convert the output tile from the linearized space within each group to the + // overall linearized space. + uint64_t output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx; + + // Bring the linearized tile ID back into the space of tiles, rather than clusters + output_tile_id *= params.get_cluster_size(); + + // The final linearized tile ID is in units of the cluster dimension over which we rasterize. + if (params.raster_order_ == RasterOrder::AlongN) { + output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; + } + else { + output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; + } + // The unit's starting k iteration in the current tile is either the starting + // iteration for the tile as a whole, or the starting k iteration for the unit + // as a whole (if the latter is greater than the former). + uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); - // The unit's starting k iteration in the current tile is either the starting - // iteration for the tile as a whole, or the starting k iteration for the unit - // as a whole (if the latter is greater than the former). - uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); + // Similarly, the unit's ending k iteration (exclusive) is either the end of + // the current tile it is assigned, or the ending iteration of the unit as a whole + // (if the latter is less than the former). + uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); - // Similarly, the unit's ending k iteration (exclusive) is either the end of - // the current tile it is assigned, or the ending iteration of the unit as a whole - // (if the latter is less than the former). - uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); + // Set the k offset to be the starting k tile for this output tile + work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); + work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; - // Set the k offset to be the starting k tile for this output tile - work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); - work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; - } + return output_tile_id; + } + // Given output tile index, update M, N, L index of current work tile info. + CUTLASS_DEVICE + static void + update_work_tile_m_n_l( + Params const& params, + uint32_t output_tile_id, + WorkTileInfo& work_tile_info, + uint64_t cta_m_in_cluster, + uint64_t cta_n_in_cluster) { uint64_t work_idx_l, remainder; params.divmod_batch_(work_idx_l, remainder, output_tile_id); @@ -907,18 +953,81 @@ class PersistentTileSchedulerSm90StreamK { work_tile_info.L_idx = static_cast(work_idx_l); } + // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info + // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining + // iterations) is used to find the next tile in the current work unit. + CUTLASS_DEVICE + static void + assign_work( + Params const& params, + uint64_t linear_idx, + dim3 block_id_in_cluster, + WorkTileInfo& work_tile_info, + uint32_t &unit_iter_start) { + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = + get_current_work_cta_m_n_in_cluster(params, linear_idx, block_id_in_cluster); + + uint64_t output_tile_id = linear_idx; + if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) { + // Separate-reduction work + auto cluster_size = params.get_cluster_size(); + // Divide up the linearized separate reduction units into clusters + uint64_t cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_)); + uint64_t cluster_tile_idx, epi_subtile_idx; + params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx); + // Bring the linearized tile ID back into the space of tiles, rather than clusters + output_tile_id = cluster_tile_idx * cluster_size; + + work_tile_info.setup_separate_reduction(epi_subtile_idx); + } + else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { + // Data-parallel work + output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; + work_tile_info.K_idx = 0; + work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; + work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; + } + else { + unit_iter_start = get_current_work_iter_start_possible_update_work_tile_k_remaining(params, linear_idx, work_tile_info); + output_tile_id = update_output_tile_id_and_work_tile_k(params, work_tile_info, + linear_idx, unit_iter_start, cta_m_in_cluster, cta_n_in_cluster); + } + update_work_tile_m_n_l(params, output_tile_id, work_tile_info, cta_m_in_cluster, cta_n_in_cluster); + } + + // The fast path to get current output tile index then update fields of work tile info + // when continuing current work tile is needed, since k tile starting index has precomputed + // in the first time fetching current work tile. + CUTLASS_DEVICE + static void + fast_assign_work( + uint32_t unit_iter_start, + Params const& params, + uint64_t linear_idx, + dim3 block_id_in_cluster, + WorkTileInfo& work_tile_info) { + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = + get_current_work_cta_m_n_in_cluster(params, linear_idx, block_id_in_cluster); + + uint64_t output_tile_id = update_output_tile_id_and_work_tile_k(params, work_tile_info, + linear_idx, unit_iter_start, cta_m_in_cluster, cta_n_in_cluster); + + update_work_tile_m_n_l(params, output_tile_id, work_tile_info, cta_m_in_cluster, cta_n_in_cluster); + } + // Returns the starting and ending peer ID of this tile CUTLASS_HOST_DEVICE static auto - tile_peer_range(Params const& params, uint32_t tile_idx, uint32_t cur_k_tile) { + tile_peer_range(Params const& params, uint32_t tile_idx, WorkTileInfo const& work_tile_info) { + uint32_t cur_k_tile = static_cast(work_tile_info.K_idx); uint32_t tile_idx_in_cluster_path = params.div_cluster_size(tile_idx); uint32_t start_k_tile = params.divmod_tiles_per_output_tile_.divisor * tile_idx_in_cluster_path; uint32_t end_k_tile = start_k_tile + params.divmod_tiles_per_output_tile_.divisor - 1; uint32_t big_unit_k_tiles = params.big_units_ * (params.divmod_k_tiles_per_sk_unit_.divisor + 1); - auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t k_tiles_per_unit) { - uint32_t unit_k_start = unit_idx * k_tiles_per_unit; - uint32_t unit_k_end = unit_k_start + k_tiles_per_unit; + auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t unit_k_start, uint32_t unit_k_end) { if (k_tile - start_k_tile < Params::min_iters_per_sk_unit_ && unit_k_end - start_k_tile < Params::min_iters_per_sk_unit_) { // k_tile is within the first min_iters_per_sk_unit_ K tiles of this output tile, @@ -943,17 +1052,22 @@ class PersistentTileSchedulerSm90StreamK { if (k_tile < big_unit_k_tiles) { // The tile is within the "big unit range" uint32_t unit_idx = params.divmod_k_tiles_per_sk_big_unit_.divide(k_tile); - return static_cast(adjust_unit(k_tile, unit_idx, params.divmod_k_tiles_per_sk_big_unit_.divisor)); + uint32_t unit_k_start = unit_idx * params.divmod_k_tiles_per_sk_big_unit_.divisor; + uint32_t unit_k_end = unit_k_start + params.divmod_k_tiles_per_sk_big_unit_.divisor; + return static_cast(adjust_unit(k_tile, unit_idx, unit_k_start, unit_k_end)); } else { // The tile is after the "big unit range." Account for this by finding the "normal unit" // that it belongs to, and then offsetting by the number of big units - uint32_t unit_idx = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles) + params.big_units_; - return static_cast(adjust_unit(k_tile, unit_idx, params.divmod_k_tiles_per_sk_unit_.divisor)); + uint32_t unit_idx_after_big_units = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles); + uint32_t unit_k_start = unit_idx_after_big_units * params.divmod_k_tiles_per_sk_unit_.divisor + (params.big_units_ * params.divmod_k_tiles_per_sk_big_unit_.divisor); + uint32_t unit_k_end = unit_k_start + params.divmod_k_tiles_per_sk_unit_.divisor; + uint32_t unit_idx = unit_idx_after_big_units + params.big_units_; + return static_cast(adjust_unit(k_tile, unit_idx, unit_k_start, unit_k_end)); } }; - return cute::make_tuple(find_unit(start_k_tile), find_unit(cur_k_tile), find_unit(end_k_tile)); + return cute::make_tuple(find_unit(start_k_tile), find_unit(start_k_tile + cur_k_tile), find_unit(end_k_tile)); } }; diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp index 2d9b63ffee..ba6b424324 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -37,15 +37,11 @@ #include "cutlass/arch/arch.h" #include "cutlass/detail/dependent_false.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm { -//////////////////////////////////////////////////////////////////////////////// - // // Tags for specifying tile schedulers // @@ -56,10 +52,12 @@ struct StreamKScheduler { }; struct GroupScheduler { }; // Only used for Grouped GEMMs -//////////////////////////////////////////////////////////////////////////////// - } // namespace cutlass::gemm +//////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel::detail { diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 0972731c2b..da8794bb56 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -50,6 +50,26 @@ namespace detail { //////////////////////////////////////////////////////////////////////////////// + CUTLASS_HOST_DEVICE + static uint32_t + get_max_cta_occupancy( + int max_sm_per_gpc, + GemmCoord cluster_shape, + int sm_count) { + // Provided SM count could possibly be less than the assumed maximum SMs per GPC + auto cluster_size = cluster_shape.m() * cluster_shape.n(); + int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; + int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); + int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; + + // The calculation below allows for larger grid size launch for different GPUs. + int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; + int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); + cta_per_device += max_cta_occupancy_per_residual_gpc; + + cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; + return cta_per_device; + } // // Parameters for SM90 tile schedulers // @@ -247,20 +267,7 @@ struct PersistentTileSchedulerSm90Params { * Hence, maximum SMs per GPC = 18 */ constexpr int max_sm_per_gpc = 18; - // Provided SM count could possibly be less than the assumed maximum SMs per GPC - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); - cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; - - // The calculation below allows for larger grid size launch for different GPUs. - int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; - int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); - cta_per_device += max_cta_occupancy_per_residual_gpc; - - if (sm_count < cta_per_device) { - cta_per_device = sm_count; - } + cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); if (raster_order == RasterOrder::AlongN) { launch_grid.y = possibly_truncate( cta_per_device / cluster_shape.m(), @@ -467,7 +474,7 @@ struct PersistentTileSchedulerSm90StreamKParams { static constexpr uint32_t max_sk_groups_ = 8u; // ktile start from even for each cta - uint32_t ktile_start_alignment_count { 1u }; + uint32_t ktile_start_alignment_count_ { 1u }; // Divides dividend by the cluster size CUTLASS_HOST_DEVICE @@ -519,7 +526,7 @@ struct PersistentTileSchedulerSm90StreamKParams { ReductionMode reduction_mode, DecompositionMode decomposition_mode, void* workspace, - const uint32_t epilogue_subtile = 1 + const uint32_t epilogue_subtile = 1u ) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl( problem_shape, tile_shape, cluster_shape); @@ -559,6 +566,15 @@ struct PersistentTileSchedulerSm90StreamKParams { void* workspace, const uint32_t epilogue_subtile = 1 ) { + + #if !defined(__CUDACC_RTC__) + if (hw_info.sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + #endif // !defined(__CUDACC_RTC__) + UnderlyingParams underlying_params; underlying_params.initialize( problem_blocks, @@ -568,115 +584,43 @@ struct PersistentTileSchedulerSm90StreamKParams { raster_order_option ); - auto problem_blocks_l = problem_blocks.z; - - auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); - uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; + // Set basic parameters that not affected by any heuristics in advance. + set_params_base(underlying_params, workspace); - // Reduction workspace is at the beginning of the workspace. Lock workspace follows. - void* reduction_workspace = workspace; - - if (decomposition_mode == DecompositionMode::SplitK || - (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { - // Short circuit to basic split-K decomposition - - // Don't split by more than the available number of SMs - if (splits > hw_info.sm_count) { - splits = hw_info.sm_count; - } - - // Don't split by more than the K tile iterations - // - // splits is almost certainly nonnegative here (e.g., hw_info.sm_count, - // despite being an int, is a count), so it can safely be converted to unsigned - // in the comparison to avoid a signed-unsigned comparison warning-as-error. - if (static_cast(splits) > k_tiles_per_output_tile) { - splits = k_tiles_per_output_tile; - } - - // If splits == k_tiles_per_output_tiles, there will be one k_tile per cta - // and this violate k_tile start from even requirements. Thus we need to - // reduce the number of splits. - if (ktile_start_alignment_count > 1u && - static_cast(splits) == k_tiles_per_output_tile) { - splits = k_tiles_per_output_tile / ktile_start_alignment_count; - } - - set_params_basic( - underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - splits, - k_tiles_per_output_tile, - reduction_workspace, - reduction_mode - ); - return; - } - - // Calculate the maximum number of blocks from clusters of shape cluster_shape that we - // can fit within sm_count SMs. - dim3 grid = get_grid_shape( + // Call for internal streamk heuristic to setup streamk related params + stream_k_heuristic( + underlying_params, problem_blocks, + k_tiles_per_output_tile, cluster_shape, hw_info, + splits, max_swizzle, - raster_order_option - ); - - uint64_t ctas_per_wave = grid.x * grid.y; - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. - uint32_t sk_tiles = get_num_sk_tiles( - output_tiles, - ctas_per_wave, - cluster_size, - k_tiles_per_output_tile, - decomposition_mode - ); - uint64_t dp_tiles = output_tiles - sk_tiles; - - // Calculate the number of work units covering the data-parallel and stream-K tiles. - // A "work unit" is a single index in the linearized ID space used by the scheduler. - // We distinguish it from a "block," which is typically tied to a hardware unit - // (e.g., the callers into this scheduler will be persistent thread blocks). - // A work unit can encompass multiple output tiles worth of work (as will be the - // case for stream-K blocks). - // Since splitting is not required for data-parallel tiles, only one data-parallel unit - // is needed per data-parallel tile. - uint64_t dp_units = dp_tiles; - - uint64_t ctas_per_sk_wave = ctas_per_wave; - uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); - - if (decomposition_mode == DecompositionMode::DataParallel || - (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || - sk_units == 0) { - // Short circuit to basic data-parallel decomposition - set_params_basic( - underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - /* splits = */ 1, - k_tiles_per_output_tile, - reduction_workspace, - reduction_mode - ); - return; - } - - bool do_separate_reduction = should_perform_separate_reduction( - epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave); + raster_order_option, + decomposition_mode, + reduction_mode, + epilogue_subtile + ); + } + + // max_sk_groups_ unless this extends beyond the extent of the dimension over + // which the problem is rasterized. For example, if the tiled problem shape + // (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16, + // and we rasterize along the M dimension, we choose 4 groups, rather than 8. + // If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M). + uint32_t calculate_groups( + UnderlyingParams underlying_params, + ReductionMode reduction_mode, + uint32_t problem_blocks_m, + uint32_t problem_blocks_n, + GemmCoord cluster_shape, + uint64_t cluster_size, + uint32_t sk_tiles, + uint64_t sk_cluster_tiles, + uint64_t sk_units, + uint32_t k_tiles_per_output_tile, + bool do_separate_reduction) { - // Determine the number of stream-K groups that will be used. We currently use - // max_sk_groups_ unless this extends beyond the extent of the dimension over - // which the problem is rasterized. For example, if the tiled problem shape - // (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16, - // and we rasterize along the M dimension, we choose 4 groups, rather than 8. - // If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M). uint32_t max_groups_problem; if (underlying_params.raster_order_ == RasterOrder::AlongM) { max_groups_problem = problem_blocks_m / cluster_shape.m(); @@ -691,14 +635,16 @@ struct PersistentTileSchedulerSm90StreamKParams { // number of K tiles per stream-K unit remains above min_iters_per_sk_unit_ uint32_t groups = platform::min(max_groups_problem, uint32_t(max_sk_groups_)); - - // Grouping is disabled when separate reduction is used - if (do_separate_reduction) { + // Grouping is disabled when separate reduction is used because grouping is primarily an attempt + // to improve L2 locality, and L2-locality optimizations are unnecessary when the the kernel + // is a single wave (which is the case for separate reduction). + if ( + do_separate_reduction + ) { groups = 1; } uint32_t fallback_groups = 0; - auto sk_cluster_tiles = sk_tiles / cluster_size; auto sk_cluster_units = sk_units / cluster_size; auto sk_splits_too_small = [&](uint32_t g) { @@ -737,82 +683,281 @@ struct PersistentTileSchedulerSm90StreamKParams { if (groups == 1 && fallback_groups > 0) { groups = fallback_groups; } + return groups; + } - auto sk_units_per_group = sk_units / groups; + // Stream-K kernel use below function to set stream-K feature related parameters to choose + // optimal/customized decomposition mode. + void stream_k_heuristic( + UnderlyingParams underlying_params, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + const uint32_t epilogue_subtile = 1 + ) { + uint32_t groups = 0; + uint32_t sk_tiles = 0; + uint64_t sk_units = 0; + uint64_t cluster_size = 0; + uint64_t dp_units = 0; + uint64_t k_tiles_per_group = 0; + uint64_t k_tiles_per_sk_unit = 0; + uint64_t sk_big_groups = 0; + uint32_t sk_splits = 1; + // Self calculated optimal heuristic mode + DecompositionMode heuristic_mode = + select_decomposition_mode( + groups, + sk_tiles, + sk_units, + cluster_size, + dp_units, + k_tiles_per_group, + k_tiles_per_sk_unit, + sk_big_groups, + sk_splits, + underlying_params, + problem_blocks, + k_tiles_per_output_tile, + cluster_shape, + hw_info, + splits, + max_swizzle, + raster_order_option, + decomposition_mode, + reduction_mode, + epilogue_subtile + ); - // sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as: - // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) - // Both total_tiles and sm_count are multiples of cluster size due to padding added - // prior to kernel launch. - uint64_t sk_cluster_tiles_per_group = sk_cluster_tiles / groups; - uint64_t sk_tiles_per_group = sk_cluster_tiles_per_group * cluster_size; + // Given heuristic_mode returned from the heuristic() method, set params fields. + // Here, we decouple the params that have no relation with + // decomposition mode from the params that are decided within heuristic(). + set_params( + heuristic_mode, + groups, + sk_tiles, + sk_units, + cluster_size, + dp_units, + k_tiles_per_group, + k_tiles_per_sk_unit, + sk_big_groups, + sk_splits, + underlying_params, + problem_blocks, + k_tiles_per_output_tile, + cluster_shape, + splits, + epilogue_subtile, + reduction_mode); + } - // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which - // are stream-K units within a group that process an extra K chunk. - uint64_t sk_big_groups = sk_cluster_tiles % groups; + // Return the optimal decomposition result by heuristic. + DecompositionMode select_decomposition_mode( + uint32_t &groups, + uint32_t &sk_tiles, + uint64_t &sk_units, + uint64_t &cluster_size, + uint64_t &dp_units, + uint64_t &k_tiles_per_group, + uint64_t &k_tiles_per_sk_unit, + uint64_t &sk_big_groups, + uint32_t &sk_splits, + UnderlyingParams underlying_params, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + uint32_t epilogue_subtile + ) { - uint64_t k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; + // Get block numbers in m, n and l dimensions + if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { + // Short circuit to basic split-K decomposition + uint32_t adapted_splits = adjust_split_count( + splits, hw_info.sm_count, k_tiles_per_output_tile + ); + sk_splits = adapted_splits; + return DecompositionMode::SplitK; + } + else { + // Calculate the maximum number of blocks from clusters of shape cluster_shape that we + // can fit within sm_count SMs. + // Get block numbers in m, n and l dimensions + auto problem_blocks_l = problem_blocks.z; + auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); + uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; + dim3 grid = get_grid_shape( + problem_blocks, + cluster_shape, + hw_info, + max_swizzle, + raster_order_option + ); + uint64_t ctas_per_wave = grid.x * grid.y; + cluster_size = cluster_shape.m() * cluster_shape.n(); + // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. + sk_tiles = get_num_sk_tiles( + output_tiles, + ctas_per_wave, + cluster_size, + k_tiles_per_output_tile, + decomposition_mode + ); + uint64_t dp_tiles = output_tiles - sk_tiles; + // Calculate the number of work units covering the data-parallel and stream-K tiles. + // A "work unit" is a single index in the linearized ID space used by the scheduler. + // We distinguish it from a "block," which is typically tied to a hardware unit + // (e.g., the callers into this scheduler will be persistent thread blocks). + // A work unit can encompass multiple output tiles worth of work (as will be the + // case for stream-K blocks). + // Since splitting is not required for data-parallel tiles, only one data-parallel unit + // is needed per data-parallel tile. + dp_units = dp_tiles; - // Number of k tiles computed per stream-K unit - uint64_t k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; + uint64_t ctas_per_sk_wave = ctas_per_wave; + sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); - uint32_t reduction_units = 0; + if (decomposition_mode == DecompositionMode::DataParallel || + (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || + sk_units == 0) { + // Short circuit to basic data-parallel decomposition + return DecompositionMode::DataParallel; + } + else { + bool do_separate_reduction = should_perform_separate_reduction( + epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave); + + uint64_t sk_cluster_tiles = sk_tiles / cluster_size; + + groups = calculate_groups(underlying_params, reduction_mode, problem_blocks_m, problem_blocks_n, cluster_shape, + cluster_size, sk_tiles, sk_cluster_tiles, sk_units, k_tiles_per_output_tile, do_separate_reduction); + + auto sk_units_per_group = sk_units / groups; + + // sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as: + // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) + // Both total_tiles and sm_count are multiples of cluster size due to padding added + // prior to kernel launch. + uint64_t sk_cluster_tiles_per_group = sk_cluster_tiles / groups; + uint64_t sk_tiles_per_group = sk_cluster_tiles_per_group * cluster_size; + + // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which + // are stream-K units within a group that process an extra K chunk. + sk_big_groups = sk_cluster_tiles % groups; + + k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; + + // Number of k tiles computed per stream-K unit + k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; + + DecompositionMode heuristic_mode; + if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { + // If the number of stream-K units is a multiple of the number of stream-K tiles, then + // the problem can leverage a basic split-K decomposition for the stream-K tiles. + // This case happens when separate reduction is disable. + sk_splits = static_cast(sk_units / sk_tiles); + heuristic_mode = DecompositionMode::SplitK; + } + else { + // Rest scenario is streamk + heuristic_mode = DecompositionMode::StreamK; + } + // Refresh heuristic_mode using analytical model before choosing streamk/separate_reduction decomposition, + // ideally it's to get the final decomposition more accuracy. Comment it as it is place holder at this moment. + #if 0 + uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); + analytical_model(heuristic_mode, k_tiles_per_output_tile, k_tiles_per_sk_unit, + sk_splits, epilogue_subtile, total_waves); + #endif + return heuristic_mode; + } + } + } - // Use separate reduction when we have less than one wave of output tiles (dp_tiles == 0) - // and when each tile will be operated on by at least two stream-K units (sk_units > 2 * sk_tiles) - if (do_separate_reduction) { - // Each reduction unit will reduce the partials of an epilogue subtile for - // a given output tile and compute the epilogue. Thus, there are as many reduction - // units as there are epilogue subtiles. - reduction_units = sk_tiles * epilogue_subtile; + // Given decomposition mode output from heuristic, set all feilds of params. + void set_params( + DecompositionMode heuristic_mode, + uint32_t groups, + uint32_t sk_tiles, + uint64_t sk_units, + uint64_t cluster_size, + uint64_t dp_units, + uint64_t k_tiles_per_group, + uint64_t k_tiles_per_sk_unit, + uint64_t sk_big_groups, + uint32_t sk_splits, + UnderlyingParams underlying_params, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord cluster_shape, + uint32_t splits, + uint32_t epilogue_subtile, + ReductionMode reduction_mode) { + // The highest priority when customers set as splitk mode, may set + // with a adpated splits value rather than the original splits + // even it does not make sense + if (splits > 1 && heuristic_mode == DecompositionMode::SplitK) { + set_params_basic( + underlying_params, + problem_blocks, + cluster_shape, + sk_splits, // split-k set by customers + k_tiles_per_output_tile, + reduction_mode + ); } - else if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { - // If the number of stream-K units is a multiple of the number of stream-K tiles, then - // the problem can leverage a basic split-K decomposition for the stream-K tiles. - // This case happens when separate reduction is disable. - uint32_t sk_splits = static_cast(sk_units / sk_tiles); + else if (heuristic_mode == DecompositionMode::DataParallel) { set_params_basic( underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - sk_splits, + problem_blocks, + cluster_shape, + 1, // fast path to fall back to the mode without any split scheme k_tiles_per_output_tile, - reduction_workspace, reduction_mode ); - return; } - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; - divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; - divmod_batch_ = underlying_params.divmod_batch_; - divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); - divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; - divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); - divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); - - // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. - // This setting ensures that the use of this divmod for stream-K decompositions - // is essentially a no-op. - divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); - divmod_splits_ = FastDivmod(1); - log_swizzle_size_ = underlying_params.log_swizzle_size_; - units_per_problem_ = static_cast(dp_units + sk_units); - raster_order_ = underlying_params.raster_order_; - - // Assign big_units_ assuming that group count == 1. This is unused by stream-K - // when group count > 1. - big_units_ = static_cast(k_tiles_per_group % k_tiles_per_sk_unit); - - big_groups_ = static_cast(sk_big_groups); - reduction_workspace_ = reduction_workspace; - sk_tiles_ = sk_tiles; - sk_units_ = static_cast(sk_units); - divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); - divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); - reduction_mode_ = reduction_mode; - divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); - separate_reduction_units_ = reduction_units; + else if (heuristic_mode == DecompositionMode::SplitK) { + set_params_basic( + underlying_params, + problem_blocks, + cluster_shape, + sk_splits, // splits calculated by heuristic + k_tiles_per_output_tile, + reduction_mode + ); + } + else { + // streamk + set_params_stream_k( + underlying_params, + k_tiles_per_output_tile, + groups, + sk_tiles, + sk_units, + cluster_size, + dp_units, + k_tiles_per_group, + k_tiles_per_sk_unit, + sk_big_groups, + reduction_mode, + 1, /*epilogue_subtile*/ + 0 /*reduction_units*/ + ); + } } // Given the inputs, computes the physical grid we should launch. @@ -897,7 +1042,6 @@ struct PersistentTileSchedulerSm90StreamKParams { // or if there is no work to be split. return 0; } - // // The final wave is not full. Perform some stream-K work. // @@ -971,11 +1115,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t accumulator_bits, uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1) { + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1) { auto log_swizzle_size = UnderlyingParams::get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle); problem_blocks.x = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); @@ -989,12 +1135,6 @@ struct PersistentTileSchedulerSm90StreamKParams { barrier_workspace_size = 0; reduction_workspace_size = 0; } - else if (splits > 1 && - (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic)) { - // Basic split-K variant requires workspace for all output tiles - barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); - } else { KernelHardwareInfo new_hw_info; new_hw_info.device_id = hw_info.device_id; @@ -1025,20 +1165,42 @@ struct PersistentTileSchedulerSm90StreamKParams { uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); uint64_t dp_tiles = output_tiles - sk_tiles; - uint64_t reduction_tiles = sk_tiles; - if (should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave)) { - // In separate reduction, each peer writes to its own location in scratch space. - // Thus, for separate reduction, we need as many reduction tiles per output tile - // as there are the maximum number of peers that can collaborate on an output tile. - reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles); + if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { + splits = adjust_split_count( + splits, new_hw_info.sm_count, k_tiles_per_output_tile + ); } - // Though separate reduction requires a larger reduction workspace, only one barrier - // is needed per output tile. Each peer will increment the barrier by one once the peer has - // written its accumulator to scratch space. The separate reduction unit will only begin - // performing the reduction when the barrier has reached the number of peers for the output tile. - barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); + bool split_k_required = splits > 1 && (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic); + bool split_k_selected = decomposition_mode == DecompositionMode::Heuristic && + sk_units > sk_tiles && + sk_tiles != 0 && + sk_units % sk_tiles == 0; + + if (split_k_required || split_k_selected) { + // Basic split-K variant requires workspace for all output tiles + barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); + reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); + } + else { + uint64_t reduction_tiles = sk_tiles; + if ( + should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave) + ) { + // In separate reduction, each peer writes to its own location in scratch space. + // Thus, for separate reduction, we need as many reduction tiles per output tile + // as there are the maximum number of peers that can collaborate on an output tile. + reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles); + } + + // Though separate reduction requires a larger reduction workspace, only one barrier + // is needed per output tile. Each peer will increment the barrier by one once the peer has + // written its accumulator to scratch space. The separate reduction unit will only begin + // performing the reduction when the barrier has reached the number of peers for the output tile. + barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits); + reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); + } } } #endif // !defined(__CUDACC_RTC__) @@ -1063,11 +1225,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile, - uint32_t num_accumulator_mtxs) { + uint32_t num_accumulator_mtxs, + uint32_t ktile_start_alignment_count = 1) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -1082,11 +1246,13 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, - num_accumulator_mtxs + num_accumulator_mtxs, + ktile_start_alignment_count ); } @@ -1104,11 +1270,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1) { + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1) { size_t barrier_workspace_size = 0; size_t reduction_workspace_size = 0; @@ -1126,11 +1294,13 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, - num_accumulator_mtxs + num_accumulator_mtxs, + ktile_start_alignment_count ); #endif @@ -1151,11 +1321,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile, - CudaHostAdapter* cuda_adapter = nullptr) { + CudaHostAdapter* cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -1172,12 +1344,14 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, 1, - cuda_adapter + cuda_adapter, + ktile_start_alignment_count ); } @@ -1197,12 +1371,14 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile = 1, uint32_t num_accumulator_mtxs = 1, - CudaHostAdapter* cuda_adapter = nullptr) { + CudaHostAdapter* cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1) { #if !defined(__CUDACC_RTC__) uint64_t barrier_workspace_size = 0; @@ -1220,11 +1396,13 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, - num_accumulator_mtxs + num_accumulator_mtxs, + ktile_start_alignment_count ); if (barrier_workspace_size > 0) { @@ -1242,31 +1420,41 @@ struct PersistentTileSchedulerSm90StreamKParams { return Status::kSuccess; } + // Set params for basic parameters, which will not affected by different decompositions. + void + set_params_base(UnderlyingParams const& underlying_params, void* reduction_workspace) { + divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; + divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; + divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; + log_swizzle_size_ = underlying_params.log_swizzle_size_; + raster_order_ = underlying_params.raster_order_; + reduction_workspace_ = reduction_workspace; + } + void set_params_basic( UnderlyingParams const& underlying_params, - uint32_t blocks_m, - uint32_t blocks_n, - uint32_t blocks_l, + dim3 problem_blocks, + GemmCoord cluster_shape, uint32_t splits, uint32_t k_tiles_per_output_tile, - void* reduction_workspace, ReductionMode reduction_mode) { - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; - divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; + auto blocks_l = problem_blocks.z; + auto blocks_m = round_up(problem_blocks.x, + (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); + auto blocks_n = round_up(problem_blocks.y, + (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); + divmod_batch_ = FastDivmodU64(blocks_m * blocks_n); divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); divmod_sk_groups_ = FastDivmodU64(1u); - auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * underlying_params.divmod_cluster_shape_minor_.divisor; + auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * + underlying_params.divmod_cluster_shape_minor_.divisor; divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size); divmod_splits_ = FastDivmod(splits); - divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; - log_swizzle_size_ = underlying_params.log_swizzle_size_; units_per_problem_ = blocks_m * blocks_n * blocks_l; - raster_order_ = underlying_params.raster_order_; big_units_ = k_tiles_per_output_tile % splits; - reduction_workspace_ = reduction_workspace; reduction_mode_ = reduction_mode; divmod_k_tiles_per_sk_unit_ = FastDivmod(k_tiles_per_output_tile / splits); divmod_k_tiles_per_sk_big_unit_ = FastDivmod(k_tiles_per_output_tile / splits + 1); @@ -1278,6 +1466,55 @@ struct PersistentTileSchedulerSm90StreamKParams { separate_reduction_units_ = 0; } + // Set params for streamk(streamk, separate-reduction included) decomposition. + void + set_params_stream_k( + UnderlyingParams const& underlying_params, + uint32_t k_tiles_per_output_tile, + uint32_t groups, + uint32_t sk_tiles, + uint64_t sk_units, + uint64_t cluster_size, + uint64_t dp_units, + uint64_t k_tiles_per_group, + uint64_t k_tiles_per_sk_unit, + uint64_t sk_big_groups, + ReductionMode reduction_mode, + uint32_t epilogue_subtile, + uint32_t reduction_units) { + // stream-k and separate-reduction decompostions + divmod_batch_ = underlying_params.divmod_batch_; + divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); + divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); + divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); + + // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. + // This setting ensures that the use of this divmod for stream-K decompositions + // is essentially a no-op. + divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); + divmod_splits_ = FastDivmod(1); + units_per_problem_ = static_cast(dp_units + sk_units); + + // Assign big_units_ assuming that group count == 1. This is unused by stream-K + // when group count > 1. + auto big_units_in_ctas = k_tiles_per_group % sk_units; + + // Store big_units in terms of clusters. big_units_in_ctas is guaranteed to be divisible + // by cluster_size because both k_tiles_per_group and k_tiles_per_sk_unit must be a multiple + // of cluster_size. + auto big_units_in_clusters = big_units_in_ctas / cluster_size; + big_units_ = static_cast(big_units_in_clusters); + + big_groups_ = static_cast(sk_big_groups); + sk_tiles_ = sk_tiles; + sk_units_ = static_cast(sk_units); + divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); + divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); + reduction_mode_ = reduction_mode; + divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); + separate_reduction_units_ = reduction_units; + } + private: // Round up number of bytes to the nearest multiple of L2 cache line alignment CUTLASS_HOST_DEVICE @@ -1286,8 +1523,31 @@ struct PersistentTileSchedulerSm90StreamKParams { constexpr size_t L2CacheLineSizeBytes = 128u; return (bytes + L2CacheLineSizeBytes - 1) / L2CacheLineSizeBytes * L2CacheLineSizeBytes; } + + CUTLASS_HOST_DEVICE + static int adjust_split_count( + int splits, + int sm_count, + uint32_t k_tiles_per_output_tile + ) { + // Don't split by more than the available number of SMs + if (splits > sm_count) { + splits = sm_count; + } + + // Don't split by more than the K tile iterations + if (static_cast(splits) > k_tiles_per_output_tile) { + splits = k_tiles_per_output_tile; + } + + // If k_tiles_per_output_tiles / splits == 1, there will be one k_tile per cta + // and this violate k_tile start from even requirements. Thus we need to + // reduce the number of splits. + return splits; + } }; + //////////////////////////////////////////////////////////////////////////////// // Parameters for SM90 persistent group scheduler (only used for Grouped Gemms) @@ -1453,18 +1713,7 @@ struct PersistentTileSchedulerSm90GroupParams { // GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU // Hence, maximum SMs per GPC = 18 constexpr int max_sm_per_gpc = 18; - // Provided SM count could possibly be less than the assumed maximum SMs per GPC - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); - int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; - - // The calculation below allows for larger grid size launch for different GPUs. - int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; - int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); - cta_per_device += max_cta_occupancy_per_residual_gpc; - - cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; + int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); if (raster_order == RasterOrder::AlongN) { launch_grid.y = possibly_truncate( diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index c778832bf8..4c70bcf3fb 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -147,7 +147,7 @@ struct MmaGeneric { CUTLASS_PRAGMA_UNROLL for (int k = 0; k < Shape::kK; ++k) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) - if (kMultipleOf2 && kAllFp32) { + if constexpr (kMultipleOf2 && kAllFp32) { //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN; n+=2) { @@ -396,34 +396,36 @@ struct MmaGeneric< CUTLASS_PRAGMA_UNROLL for (int k = 0; k < Shape::kK; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; ++n) { - + { CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { + for (int n = 0; n < Shape::kN; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { - int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; + int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; + + MatrixCoord mn(m_serpentine, n); + MatrixCoord mk(m_serpentine, k); + MatrixCoord kn(k, n); - MatrixCoord mn(m_serpentine, n); - MatrixCoord mk(m_serpentine, k); - MatrixCoord kn(k, n); + Array d; + Array a; + Array b; - Array d; - Array a; - Array b; + d[0] = d_ref.at(mn); + a[0] = a_ref.at(mk); + b[0] = b_ref.at(kn); - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); + if ((m == 0 && n) || m == Shape::kM - 1) { + mma_corner(d, a, b, d); + } + else { + mma_column(d, a, b, d); + } - if ((m == 0 && n) || m == Shape::kM - 1) { - mma_corner(d, a, b, d); - } - else { - mma_column(d, a, b, d); + d_ref.at(mn) = d[0]; } - - d_ref.at(mn) = d[0]; } } } diff --git a/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/include/cutlass/gemm/threadblock/ell_mma_multistage.h index 27f410ccd1..17cc9dae85 100644 --- a/include/cutlass/gemm/threadblock/ell_mma_multistage.h +++ b/include/cutlass/gemm/threadblock/ell_mma_multistage.h @@ -243,12 +243,12 @@ class EllMmaMultistage : if (is_offset_constant){ auto ell_offset = ell_iter.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorA::Element) / kSrcBytes; } else { int k_offset = iterator_A.get_k(); auto ell_offset = ell_iter.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; + gmem_ptr += (ell_offset * sizeof(typename IteratorA::Element)) / kSrcBytes; } } @@ -287,12 +287,12 @@ class EllMmaMultistage : if (is_offset_constant){ auto ell_offset = ell_iter.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorB::Element) / kSrcBytes; } else { int k_offset = iterator_B.get_k(); auto ell_offset = ell_iter.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; + gmem_ptr += ( ell_offset * sizeof(typename IteratorB::Element)) / kSrcBytes; } } @@ -359,12 +359,12 @@ class EllMmaMultistage : if (is_offset_constant){ auto ell_offset = ell_iterator.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorA::Element) / kSrcBytes; } else { int k_offset = iterator_A.get_k(); auto ell_offset = ell_iterator.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; + gmem_ptr += (ell_offset * sizeof(typename IteratorA::Element)) / kSrcBytes; } } @@ -401,12 +401,12 @@ class EllMmaMultistage : if (is_offset_constant){ auto ell_offset = ell_iterator.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorB::Element) / kSrcBytes; } else { int k_offset = iterator_B.get_k(); auto ell_offset = ell_iterator.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; + gmem_ptr += ( ell_offset * sizeof(typename IteratorB::Element)) / kSrcBytes; } } diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index b84d322dbb..27a50fd290 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -93,7 +93,7 @@ struct integer_subbyte { [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; assert(value >= lower_bound); - assert(value < upper_bound); + assert(value <= upper_bound); } else { [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; @@ -112,7 +112,7 @@ struct integer_subbyte { [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; assert(value >= lower_bound); - assert(value < upper_bound); + assert(value <= upper_bound); } else { [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; @@ -120,6 +120,10 @@ struct integer_subbyte { } } + CUTLASS_HOST_DEVICE explicit + integer_subbyte(uint8_t value) + : integer_subbyte(static_cast(value)) {} + // Convert to the "external" integer type (int or unsigned) CUTLASS_HOST_DEVICE operator xint_t() const { diff --git a/include/cutlass/kernel_launch.h b/include/cutlass/kernel_launch.h index ca3380a2a1..4cd087a3b3 100644 --- a/include/cutlass/kernel_launch.h +++ b/include/cutlass/kernel_launch.h @@ -37,6 +37,7 @@ #include #include "cutlass/cutlass.h" #include "cutlass/trace.h" +#include "cutlass/device_kernel.h" // cutlass::device_kernel namespace cutlass { diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index 912eb2c8cf..13e5ef222f 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -38,11 +38,8 @@ computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_} as new addresses after permute op. */ #pragma once -#if defined(__CUDACC_RTC__) + #include -#else -#include "assert.h" -#endif #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 8374fe31d0..d296f1d04b 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -39,11 +39,9 @@ defined in cutlass/tensor_ref.h. */ #pragma once -#if defined(__CUDACC_RTC__) + #include -#else -#include "assert.h" -#endif + #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/include/cutlass/layout/tensor_op_multiplicand_sm70.h index 4691b98298..b260942a73 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm70.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm70.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/coord.h" #include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_coord.h" // cutlass::MatrixCoord ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 9c4fb395af..b62a90ccac 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -95,7 +95,6 @@ struct NumericConverter { // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) template <> struct NumericConverter { @@ -103,50 +102,17 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if __CUDA_ARCH__ return __float2int_rn(s); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = int32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_DEVICE - static result_type convert(source_type const & s) { - - return __float2int_rz(s); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#elif !defined(__CUDACC_RTC__) - -template <> -struct NumericConverter { - - using result_type = int32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - static result_type convert(source_type const & s) { + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TONEAREST); - return (result_type)std::nearbyint(s); + return static_cast(std::nearbyint(s)); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -159,16 +125,21 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if __CUDA_ARCH__ + return __float2int_rz(s); + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TOWARDZERO); return (result_type)std::nearbyint(s); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } }; -#endif ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -176,7 +147,6 @@ struct NumericConverter { // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) template <> struct NumericConverter { @@ -184,109 +154,24 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - CUTLASS_DEVICE - static result_type convert(source_type const & s) { - - int32_t intermediate; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - - return static_cast(intermediate); - } - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_DEVICE static result_type convert(source_type const & s) { - - int32_t intermediate; - asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - - return static_cast(intermediate); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = uint8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - CUTLASS_DEVICE - static result_type convert(source_type const & s) { - - int32_t intermediate; - asm volatile("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - - return static_cast(intermediate); - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = uint8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_DEVICE - static result_type convert(source_type const & s) { - + #if defined(__CUDA_ARCH__) int32_t intermediate; - asm volatile("cvt.rzi.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); return static_cast(intermediate); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#elif !defined(__CUDACC_RTC__) - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - static result_type convert(source_type const & s) { + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TONEAREST); int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -299,19 +184,24 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if defined(__CUDA_ARCH__) + int32_t intermediate; + asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); + return static_cast(intermediate); + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TOWARDZERO); int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -324,19 +214,24 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if defined(__CUDA_ARCH__) + int32_t intermediate; + asm volatile("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); + return static_cast(intermediate); + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TONEAREST); int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -349,26 +244,29 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if __CUDA_ARCH__ + int32_t intermediate; + asm volatile("cvt.rzi.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); + return static_cast(intermediate); + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TOWARDZERO); int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } }; -#endif - ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for float => integer_subbyte @@ -3281,88 +3179,88 @@ namespace detail { ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - unsigned const& storage = reinterpret_cast(source); - unsigned out[2]; - - asm volatile( - "{\n" - " .reg .u32 tmp0, tmp1, tmp2;\n" - " shl.b32 tmp0, %2, 4;\n" // tmp0 = x1x2x3x4x5x6x7__ - " and.b32 tmp0, tmp0, 0xf0f0f0f0;\n" // tmp0 = x1__x3__x5__x7__ - " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s1s3s5s7 - " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s1__s3__s5__s7__ - " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x1__x3__x5__x7 - " or.b32 tmp2, tmp0, tmp1;\n" // tmp2 = y1y3y5y7 - " and.b32 tmp0, %2, 0xf0f0f0f0;\n" // tmp0 = x0__x2__x4__x6__ - " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s0s2s4s6 - " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s0__s2__s4__s6__ - " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x0__x2__x4__x6 - " or.b32 tmp0, tmp0, tmp1;\n" // tmp0 = y0y2y4y6 - " prmt.b32 %0, tmp2, tmp0, 0x5140;\n" // %0 = y0y1y2y3 - " prmt.b32 %1, tmp2, tmp0, 0x7362;\n" // %1 = y4y5y6y7 - "}\n" - : "=r"(out[0]), "=r"(out[1]) - : "r"(storage)); - - return reinterpret_cast(out); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - /// Partial specialization for Array <= Array template < int N, FloatRoundStyle Round > struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); + + static_assert(N % 8 == 0, "N must be a multiple of 8"); using result_type = Array; using source_type = Array; static FloatRoundStyle const round_style = Round; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { + + #if defined(__CUDA_ARCH__) - NumericArrayConverter convert_vector_; - - result_type result; + if constexpr ( N == 8 ) { + + unsigned const& storage = reinterpret_cast(source); + unsigned out[2]; - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); + asm volatile( + "{\n" + " .reg .u32 tmp0, tmp1, tmp2;\n" + " shl.b32 tmp0, %2, 4;\n" // tmp0 = x1x2x3x4x5x6x7__ + " and.b32 tmp0, tmp0, 0xf0f0f0f0;\n" // tmp0 = x1__x3__x5__x7__ + " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s1s3s5s7 + " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s1__s3__s5__s7__ + " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x1__x3__x5__x7 + " or.b32 tmp2, tmp0, tmp1;\n" // tmp2 = y1y3y5y7 + " and.b32 tmp0, %2, 0xf0f0f0f0;\n" // tmp0 = x0__x2__x4__x6__ + " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s0s2s4s6 + " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s0__s2__s4__s6__ + " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x0__x2__x4__x6 + " or.b32 tmp0, tmp0, tmp1;\n" // tmp0 = y0y2y4y6 + " prmt.b32 %0, tmp2, tmp0, 0x5140;\n" // %0 = y0y1y2y3 + " prmt.b32 %1, tmp2, tmp0, 0x7362;\n" // %1 = y4y5y6y7 + "}\n" + : "=r"(out[0]), "=r"(out[1]) + : "r"(storage)); + return reinterpret_cast(out); + + } else { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + #else + + result_type result; + NumericConverter convert_; + CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); + for (int i = 0; i < N; ++i) { + result[i] = convert_(source[i]); } - + return result; + + #endif // __CUDA_ARCH__ } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } }; -#endif // defined(__CUDA_ARCH__) /// Partial specialization for Array <= Array template diff --git a/include/cutlass/numeric_size.h b/include/cutlass/numeric_size.h index 4ff83bab88..98fd77c394 100644 --- a/include/cutlass/numeric_size.h +++ b/include/cutlass/numeric_size.h @@ -68,6 +68,15 @@ bits_to_bytes(T bits) { return (R(bits) + R(7)) / R(8); } +/// Returns the number of bits required to hold a specified number of bytes +template +CUTLASS_HOST_DEVICE +constexpr +R +bytes_to_bits(T bytes) { + return R(bytes) * R(8); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index 5519fbe7c9..ca37896bca 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -34,8 +34,6 @@ */ #pragma once -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" #include "cutlass/numeric_size.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 96bb8db745..b1d04f5117 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -51,6 +51,65 @@ namespace cutlass { using namespace cute; +namespace detail { + +// Helper function for DEBUG checks +template +CUTLASS_DEVICE +bool pipeline_is_producer(ThreadCategory role) { + return (role == ThreadCategory::Producer || role == ThreadCategory::ProducerConsumer); +} + +template +CUTLASS_DEVICE +void pipeline_check_is_producer(ThreadCategory role) { + #ifndef NDEBUG + if (!pipeline_is_producer(role)) { + asm volatile ("brkpt;\n" ::); + } + #endif +} + +template +CUTLASS_DEVICE +bool pipeline_is_consumer(ThreadCategory role) { + return (role == ThreadCategory::Consumer || role == ThreadCategory::ProducerConsumer); +} + +template +CUTLASS_DEVICE +void pipeline_check_is_consumer(ThreadCategory role) { + #ifndef NDEBUG + if (!pipeline_is_consumer(role)) { + asm volatile ("brkpt;\n" ::); + } + #endif +} + +CUTLASS_DEVICE +cute::tuple spread_arrivals_to_warp(int thread_idx_in_warp) { + constexpr uint32_t MaxClusterSize = 16; + bool is_signaling_thread = (thread_idx_in_warp % (32 / MaxClusterSize)) == 0; + auto layout = Layout,Stride<_4, _1>>{}; + uint32_t thread_row = thread_idx_in_warp / 8; + uint32_t thread_col = (thread_idx_in_warp % 8) / 2; + uint32_t dst_blockid = layout(thread_row, thread_col); + return cute::make_tuple(is_signaling_thread, dst_blockid); +} + +CUTLASS_DEVICE +cute::tuple spread_arrivals_to_warpgroup(int thread_idx_in_warpgroup, int warp_idx) { + constexpr uint32_t MaxClusterSize = 16; + bool is_signaling_thread = (thread_idx_in_warpgroup % (NumThreadsPerWarpGroup / MaxClusterSize)) == 0; + auto layout = cute::composition(Swizzle<2,0,-2>{}, + Layout,Stride<_4,_1>>{}); + uint32_t thread_row = warp_idx % 4; + uint32_t thread_col = (thread_idx_in_warpgroup / 8) % 4; + uint32_t dst_blockid = layout(thread_row, thread_col); + return cute::make_tuple(is_signaling_thread, dst_blockid); +} +} // namespace detail + enum class BarrierStatus : uint32_t { WaitAgain = 0u, WaitDone = 1u, @@ -210,7 +269,7 @@ PipelineState make_producer_start_state() { // Currently, it is optional to elect a leader for the Consumers template class PipelineTmaAsync { -public : +public: using FullBarrier = cutlass::arch::ClusterTransactionBarrier; using EmptyBarrier = cutlass::arch::ClusterBarrier; using ProducerBarrierType = FullBarrier::ValueType; @@ -237,68 +296,92 @@ public : uint32_t num_consumers = 0; }; - // Constructor - template + template + static CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape) + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1 + if (cute::size(cluster_shape) > 1) { + multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * + num_consumer_warpgroups_per_cluster; + } + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) : params_(params) , full_barrier_ptr_(&storage.full_barrier_[0]) , empty_barrier_ptr_(&storage.empty_barrier_[0]) { int warp_idx = canonical_warp_idx_sync(); + int thread_idx = threadIdx.x; int lane_predicate = cute::elect_one_sync(); - if (warp_idx == 0 && lane_predicate == 1) { - // Barrier FULL init - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(1); + static_assert(cute::is_same_v || cute::is_same_v); + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + if constexpr (cute::is_same_v) { + // Logic to optimally schedule Empty Arrives + // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) + dim3 block_id = cute::block_id_in_cluster(); + auto cluster_size = cute::size(cluster_shape); + + if (cluster_size == 1) { + is_signaling_thread_ = true; + dst_blockid_ = 0; } - uint32_t const num_consumer_warpgroups_per_cluster = params_.num_consumers / NumThreadsPerWarpGroup; - uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * - num_consumer_warpgroups_per_cluster; - // Barrier EMPTY init - for (int i = 0; i < Stages; ++i) { - empty_barrier_ptr_[i].init(multicast_consumer_arrival_count); + else { + // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) + if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { + auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warpgroup(thread_idx % NumThreadsPerWarpGroup, warp_idx); + is_signaling_thread_ = is_signaling_thread; + dst_blockid_ = dst_blockid; + } + else if (params_.num_consumers == 32) { + auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warp(thread_idx % 32); + is_signaling_thread_ = is_signaling_thread; + dst_blockid_ = dst_blockid; + } + else { + is_signaling_thread_ = 0; + #ifndef NDEBUG + asm volatile ("brkpt;\n" ::); + #endif + } + + // STEP 2: Find if this dst block-id needs an arrival for this problem + is_signaling_thread_ &= dst_blockid_ < cluster_size; + is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); } } - cutlass::arch::fence_barrier_init(); - - // Logic to optimally schedule Empty Arrives - // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) - dim3 block_id = cute::block_id_in_cluster(); - auto cluster_size = cute::size(cluster_shape); - static constexpr int MaxClusterSize = 16; - - // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) - if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { - int thread_idx = threadIdx.x % NumThreadsPerWarpGroup; - is_signalling_thread_ = (thread_idx % (NumThreadsPerWarpGroup / MaxClusterSize)) == 0; - auto layout = cute::composition(Swizzle<2,0,-2>{}, - Layout,Stride<_4,_1>>{}); - uint32_t thread_row = warp_idx % 4; - uint32_t thread_col = (thread_idx / 8) % 4; - dst_blockid_ = layout(thread_row, thread_col); - } - else if (params_.num_consumers == 32) { - int thread_idx = threadIdx.x % 32; - is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; - auto layout = Layout,Stride<_4, _1>>{}; - uint32_t thread_row = thread_idx / 8; - uint32_t thread_col = (thread_idx % 8) / 2; - dst_blockid_ = layout(thread_row, thread_col); - } - else { - is_signalling_thread_ = 0; - #ifndef NDEBUG - asm volatile ("brkpt;\n" ::); - #endif - } - - // STEP 2: Find if this dst block-id needs an arrival for this problem - is_signalling_thread_ &= dst_blockid_ < cluster_size; - is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); } + // Constructor + template + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape) + : PipelineTmaAsync(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } + + template + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) + : PipelineTmaAsync(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } + template CUTLASS_DEVICE bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { @@ -347,6 +430,7 @@ public : // This should be called once before kernel exits. CUTLASS_DEVICE void producer_tail(PipelineState state) { + detail::pipeline_check_is_producer(params_.role); for (int count = 0; count < Stages; ++count) { empty_barrier_ptr_[state.index()].wait(state.phase()); ++state; @@ -386,15 +470,16 @@ public : consumer_release(state.index()); } -private : +private: uint32_t dst_blockid_ = 0; - uint32_t is_signalling_thread_ = 0; + uint32_t is_signaling_thread_ = 0; FullBarrier *full_barrier_ptr_ = nullptr; EmptyBarrier *empty_barrier_ptr_ = nullptr; Params params_; CUTLASS_DEVICE ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -404,6 +489,7 @@ private : CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); if (barrier_token != BarrierStatus::WaitDone) { empty_barrier_ptr_[stage].wait(phase); } @@ -454,6 +540,7 @@ private : CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -463,6 +550,7 @@ private : CUTLASS_DEVICE ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -473,12 +561,14 @@ private : // Wait for producer to commit transactions (done by TMA) CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase) { + detail::pipeline_check_is_consumer(params_.role); full_barrier_ptr_[stage].wait(phase); } // Wait for producer to commit transactions (done by TMA) CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { full_barrier_ptr_[stage].wait(phase); } @@ -488,7 +578,8 @@ private : // Ensures all blocks in the Same Row and Column get notifed. CUTLASS_DEVICE void consumer_release(uint32_t stage, uint32_t skip = false) { - empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); + detail::pipeline_check_is_consumer(params_.role); + empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signaling_thread_ & (!skip)); #ifndef NDEBUG if (params_.role == ThreadCategory::Producer || params_.role == ThreadCategory::NonParticipant) { asm volatile ("brkpt;\n" ::); @@ -625,7 +716,7 @@ class PipelineTmaStore< /* Stages_ = */ 0, /* UnacquiredStages = Stages_ - 1 = * /////////////////////////////////////////////////////////////////////////////////////////////////// template class PipelineTransactionAsync { -public : +public: using FullBarrier = cutlass::arch::ClusterTransactionBarrier; using EmptyBarrier = cutlass::arch::ClusterBarrier; using ProducerBarrierType = FullBarrier::ValueType; @@ -653,26 +744,45 @@ public : uint32_t dst_blockid = cute::block_rank_in_cluster(); }; + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params const& params) { + FullBarrier *full_barrier_ptr = storage.full_barrier_.data(); + EmptyBarrier *empty_barrier_ptr = storage.empty_barrier_.data(); + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + full_barrier_ptr, empty_barrier_ptr, params.producer_arv_count, params.consumer_arv_count); + } + } + // Constructor + template CUTLASS_DEVICE - PipelineTransactionAsync(SharedStorage& storage, Params const& params) + PipelineTransactionAsync(SharedStorage& storage, Params const& params, InitBarriers = cute::true_type{}) : params_(params) , full_barrier_ptr_(storage.full_barrier_.data()) , empty_barrier_ptr_(storage.empty_barrier_.data()) { + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); - // Barrier FULL, EMPTY init - // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate) { - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(params.producer_arv_count); - empty_barrier_ptr_[i].init(params.consumer_arv_count); - } + static_assert(cute::is_same_v || cute::is_same_v); + + if constexpr (cute::is_same_v) { + init_barriers(storage, params); } - cutlass::arch::fence_barrier_init(); + } + // Constructor + CUTLASS_DEVICE + PipelineTransactionAsync(SharedStorage& storage, Params const& params) : + PipelineTransactionAsync(storage, params, cute::true_type{}) { } + //////////////////// // Producer APIs //////////////////// @@ -758,6 +868,7 @@ public : CUTLASS_DEVICE ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -767,6 +878,7 @@ public : CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { empty_barrier_ptr_[stage].wait(phase); } @@ -775,11 +887,13 @@ public : // Perform an expect-tx operation on the stage's full barrier. Must be called by 1 thread CUTLASS_DEVICE void producer_expect_transaction(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); full_barrier_ptr_[stage].expect_transaction(params_.transaction_bytes); } CUTLASS_DEVICE void producer_commit(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); full_barrier_ptr_[stage].arrive(params_.dst_blockid); } @@ -790,6 +904,7 @@ public : CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -799,6 +914,7 @@ public : CUTLASS_DEVICE ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -808,6 +924,7 @@ public : CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { full_barrier_ptr_[stage].wait(phase); } @@ -815,6 +932,7 @@ public : CUTLASS_DEVICE void consumer_release(uint32_t stage, uint32_t skip = false) { + detail::pipeline_check_is_consumer(params_.role); empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); } }; @@ -841,7 +959,7 @@ namespace PipelineDetail { template class PipelineAsync { -public : +public: static constexpr uint32_t Stages = Stages_; using SharedStorage = PipelineDetail::PipelineAsyncSharedStorage; using FullBarrier = typename SharedStorage::FullBarrier; @@ -864,34 +982,47 @@ public : uint32_t dst_blockid = cute::block_rank_in_cluster(); }; - // Default assumption when only storage is passed is : - // => single producer, single consumer & they are in the same block (within the Cluster) + static CUTLASS_DEVICE - PipelineAsync(SharedStorage& storage) - : PipelineAsync(storage, {}) {} + void + init_barriers(SharedStorage& storage, Params params) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); + } + } + template CUTLASS_DEVICE PipelineAsync( SharedStorage& storage, - Params const& params) : + Params const& params, + InitBarriers = {}) : params_(params), full_barrier_ptr_(&storage.full_barrier_[0]), empty_barrier_ptr_(&storage.empty_barrier_[0]) { - int warp_idx = canonical_warp_idx_sync(); - int lane_predicate = cute::elect_one_sync(); - - // Barrier FULL, EMPTY init - // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate == 1) { - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(params.producer_arv_count); - empty_barrier_ptr_[i].init(params.consumer_arv_count); - } + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_); } - cutlass::arch::fence_barrier_init(); } + CUTLASS_DEVICE + PipelineAsync( + SharedStorage& storage, + Params const& params) : + PipelineAsync(storage, params, cute::true_type{}) { } + + // Default assumption when only storage is passed is : + // => single producer, single consumer & they are in the same block (within the Cluster) + CUTLASS_DEVICE + PipelineAsync(SharedStorage& storage) + : PipelineAsync(storage, {}, cute::true_type{}) {} + //////////////////// // Producer APIs //////////////////// @@ -983,6 +1114,7 @@ public : CUTLASS_DEVICE ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -992,6 +1124,7 @@ public : CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { empty_barrier_ptr_[stage].wait(phase); } @@ -999,11 +1132,13 @@ public : CUTLASS_DEVICE void producer_commit(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); full_barrier_ptr_[stage].arrive(); } CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -1013,6 +1148,7 @@ public : CUTLASS_DEVICE ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -1022,6 +1158,7 @@ public : CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase) { + detail::pipeline_check_is_consumer(params_.role); bool done = full_barrier_ptr_[stage].test_wait(phase); if (!done) { full_barrier_ptr_[stage].wait(phase); @@ -1030,6 +1167,7 @@ public : CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { full_barrier_ptr_[stage].wait(phase); } @@ -1037,6 +1175,7 @@ public : CUTLASS_DEVICE void consumer_release(uint32_t stage) { + detail::pipeline_check_is_consumer(params_.role); empty_barrier_ptr_[stage].arrive(params_.dst_blockid); } }; @@ -1075,7 +1214,7 @@ class OrderedSequenceBarrier { uint32_t group_size; }; -private : +private: // In future this Params object can be replaced easily with a CG object Params params_; Barrier *barrier_ptr_; @@ -1110,7 +1249,6 @@ private : } } } - cutlass::arch::fence_barrier_init(); } // Wait on a stage to be unlocked diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index ba1f74011f..13e018db8c 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -106,7 +106,11 @@ #include #include #else -#include +#include +#include +#include +#include +#include #endif #if !defined(__CUDACC_RTC__) @@ -134,6 +138,10 @@ #define CUTLASS_OS_WINDOWS #endif +#if defined(__clang__) && defined(__CUDA__) +#define CUTLASS_CLANG_CUDA 1 +#endif + /****************************************************************************** * Macros ******************************************************************************/ @@ -298,30 +306,13 @@ namespace platform { #if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -/// std::integral_constant -template -struct integral_constant; - -/// std::integral_constant -template -struct integral_constant { - static const value_t value = V; - - typedef value_t value_type; - typedef integral_constant type; - - CUTLASS_HOST_DEVICE operator value_type() const { return value; } - - CUTLASS_HOST_DEVICE const value_type operator()() const { return value; } -}; - #else -using std::integral_constant; using std::pair; #endif +using CUTLASS_STL_NAMESPACE::integral_constant; using CUTLASS_STL_NAMESPACE::bool_constant; using CUTLASS_STL_NAMESPACE::true_type; using CUTLASS_STL_NAMESPACE::false_type; diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index aa4e3f1a12..e878156277 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -35,15 +35,14 @@ #pragma once #if defined(__CUDACC_RTC__) -#include #include #else -#include -#include +#include #endif -#include "cutlass/cutlass.h" +#include +#include "cutlass/cutlass.h" #include "cutlass/platform/platform.h" namespace cutlass { diff --git a/include/cutlass/real.h b/include/cutlass/real.h index e53301b3ff..95a22444f0 100644 --- a/include/cutlass/real.h +++ b/include/cutlass/real.h @@ -35,6 +35,8 @@ #pragma once +#include // CUTLASS_DEVICE + namespace cutlass { /// Used to determine the real-valued underlying type of a numeric type T. diff --git a/include/cutlass/reduction/thread/reduction_operators.h b/include/cutlass/reduction/thread/reduction_operators.h index ba62c1b50e..8423c2d933 100644 --- a/include/cutlass/reduction/thread/reduction_operators.h +++ b/include/cutlass/reduction/thread/reduction_operators.h @@ -172,7 +172,7 @@ struct ReduceArrayOperation, uint1b_t, N> { item = (item || !bits); } - return uint1b_t(!item); + return uint1b_t{!item}; } }; @@ -195,7 +195,7 @@ struct ReduceArrayOperation, uint1b_t, N> { item = (item || bits); } - return uint1b_t(item); + return uint1b_t{item}; } }; diff --git a/include/cutlass/tensor_view_planar_complex.h b/include/cutlass/tensor_view_planar_complex.h index c98de563ff..af63f80cd3 100644 --- a/include/cutlass/tensor_view_planar_complex.h +++ b/include/cutlass/tensor_view_planar_complex.h @@ -48,6 +48,7 @@ #include "cutlass/cutlass.h" #include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view.h" // cutlass::TensorView namespace cutlass { diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 8e7ab884cf..d6d265a430 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -40,6 +40,7 @@ #include #include #include +#include // std::memcpy #endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/transform/device/transform_universal_adapter.hpp b/include/cutlass/transform/device/transform_universal_adapter.hpp index c7ab0ceb07..a5033d80eb 100644 --- a/include/cutlass/transform/device/transform_universal_adapter.hpp +++ b/include/cutlass/transform/device/transform_universal_adapter.hpp @@ -59,7 +59,7 @@ template class TransformUniversalAdapter { public: - using TransformKernel = TransformKernel_; + using TransformKernel = GetUnderlyingKernel_t; using Arguments = typename TransformKernel::Arguments; using Params = typename TransformKernel::Params; static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; diff --git a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp index 0ae7bab062..dd4fa0c14a 100644 --- a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +++ b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -267,7 +267,9 @@ class SM90StructuredSparseCompressor { case 3: return 0b11; default: + CUTLASS_ASSERT(false); CUTE_GCC_UNREACHABLE; + return 0b00; } }; diff --git a/media/docs/cute/04_algorithms.md b/media/docs/cute/04_algorithms.md index 353e9cc474..f427e5ef81 100644 --- a/media/docs/cute/04_algorithms.md +++ b/media/docs/cute/04_algorithms.md @@ -100,7 +100,7 @@ void copy(Tensor const& src, // Any logical shape Tensor & dst) // Any logical shape { - for (int i = 0; i < size(src); ++i) { + for (int i = 0; i < size(dst); ++i) { dst(i) = src(i); } } diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 97ed6a631f..29e5a0f6f3 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -185,7 +185,6 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GP ``` **NVIDIA Ampere Architecture.** - ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture ``` diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index f9723c4651..fad278370f 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -57,6 +57,19 @@ def _cuda_install_path_from_nvcc() -> str: # Alias CUTLASS_PATH as source_path source_path = CUTLASS_PATH +_NVCC_VERSION = None +def nvcc_version(): + global _NVCC_VERSION + if _NVCC_VERSION is None: + import subprocess + + # Attempt to get NVCC version + result = subprocess.run(['nvcc', '--version'], capture_output=True) + if result.returncode != 0: + raise Exception('Unable to run `nvcc --version') + _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0] + return _NVCC_VERSION + _CUDA_INSTALL_PATH = None def cuda_install_path(): """ diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index d72af78eae..95e264cd89 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -139,7 +139,7 @@ def get_tile_scheduler_arguments_3x( splits: int = 1): max_swizzle_size = 1 raster_order_option = 0 # Heuristic - if tile_scheduler == TileSchedulerType.Persistent: + if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]: return _PersistentTileSchedulerArguments( max_swizzle_size, raster_order_option, diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index f52b181831..2c38397d3f 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -90,7 +90,7 @@ def get_str(self): opts.append(f"--include-path={incl}") arch_flag = f"-arch=sm_{self.arch}" - if self.arch == 90: + if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12: arch_flag += "a" opts.append(arch_flag) @@ -237,7 +237,7 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op if incl not in includes: includes.append(incl) - includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes + includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes for incl in includes: source_buffer_device += SubstituteTemplate( IncludeTemplate, diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass/backend/epilogue.py index e0b5e9574f..48366a7609 100644 --- a/python/cutlass/backend/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -44,6 +44,7 @@ dtype2ctype = { DataType.f16: ctypes.c_uint16, + DataType.bf16: ctypes.c_uint16, DataType.f32: ctypes.c_float, DataType.f64: ctypes.c_double, DataType.s8: ctypes.c_int8, diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index 575767d03f..153b937e65 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -59,18 +59,21 @@ def max(x, dim): elif is_torch_tensor(x): return torch.amax(x, dim) + def maximum(x, y): if is_numpy_tensor(x): return np.maximum(x, y) elif is_torch_tensor(x): return torch.maximum(x, torch.tensor(y)) - + + def minimum(x, y): if is_numpy_tensor(x): return np.minimum(x, y) elif is_torch_tensor(x): return torch.minimum(x, torch.tensor(y)) + ############################################################################## # Layout manipulate nodes ############################################################################## diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 7c16cc6855..2a02f61c90 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -51,6 +51,20 @@ # Strip any additional information from the CUDA version _cuda_version = __version__.split("rc")[0] +# Check that Python CUDA version exceeds NVCC version +_nvcc_version = cutlass.nvcc_version() +_cuda_list = _cuda_version.split('.') +_nvcc_list = _cuda_version.split('.') +for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list): + if int(val_cuda) < int(val_nvcc): + raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}") + +if len(_nvcc_list) > len(_cuda_list): + if len(_nvcc_list) != len(_cuda_list) + 1: + raise Exception(f"Malformatted NVCC version of {_nvcc_version}") + if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0: + raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}") + class KernelsForDataType: """ @@ -278,7 +292,7 @@ def __init__( ] manifest_args = cutlass_library.generator.define_parser().parse_args(args) manifest = cutlass_library.manifest.Manifest(manifest_args) - generate_function(manifest, _cuda_version) + generate_function(manifest, _nvcc_version) if operation_kind not in manifest.operations: # No kernels generated for this architecture, this could be because the CUDA diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 6f8483ed82..62a5474ae7 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -818,6 +818,7 @@ def emit(self, operation): element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index e6a9f9e8e5..bd06a8016a 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -177,7 +177,7 @@ def CreateGemmUniversal3xOperator( complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Persistent]): + tile_schedulers=[TileSchedulerType.Default]): if type(data_types) is dict: data_types = [data_types] @@ -226,7 +226,7 @@ def CreateSparseGemmUniversal3xOperator( complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Persistent]): + tile_schedulers=[TileSchedulerType.Default]): if type(data_types) is dict: data_types = [data_types] @@ -1048,7 +1048,7 @@ def CreateConvOperator3x(manifest: Manifest, schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \ [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)], complex_transforms: Optional[Sequence[ComplexTransform]] = None, - tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Persistent], + tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default], conv_kind: ConvKind = ConvKind.Fprop, log_indent_level: int = 1): """ @@ -6508,6 +6508,7 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): data_type, alignment_constraints, BlasMode.hermitian) # + ################################################################################################### def GenerateSM90_Conv3x(manifest, cuda_version, @@ -6703,6 +6704,7 @@ def GenerateSM90_Conv3x(manifest, cuda_version, product( ( ConvKind.Dgrad, + ConvKind.Wgrad ), spatial_dims, ( diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index be9eef20ed..3ccfb403ff 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -75,6 +75,7 @@ class DataType(enum.Enum): u16 = enum_auto() u32 = enum_auto() u64 = enum_auto() + s2 = enum_auto() s4 = enum_auto() s8 = enum_auto() s16 = enum_auto() @@ -92,11 +93,13 @@ class DataType(enum.Enum): cf32 = enum_auto() ctf32 = enum_auto() cf64 = enum_auto() + cs2 = enum_auto() cs4 = enum_auto() cs8 = enum_auto() cs16 = enum_auto() cs32 = enum_auto() cs64 = enum_auto() + cu2 = enum_auto() cu4 = enum_auto() cu8 = enum_auto() cu16 = enum_auto() @@ -126,6 +129,7 @@ class DataType(enum.Enum): DataType.u16: "u16", DataType.u32: "u32", DataType.u64: "u64", + DataType.s2: "s2", DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16", @@ -143,11 +147,13 @@ class DataType(enum.Enum): DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64", + DataType.cu2: "cu2", DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16", DataType.cu32: "cu32", DataType.cu64: "cu64", + DataType.cs2: "cs2", DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16", @@ -164,6 +170,7 @@ class DataType(enum.Enum): DataType.u16: "uint16_t", DataType.u32: "uint32_t", DataType.u64: "uint64_t", + DataType.s2: "cutlass::int2b_t", DataType.s4: "cutlass::int4b_t", DataType.s8: "int8_t", DataType.s16: "int16_t", @@ -181,11 +188,13 @@ class DataType(enum.Enum): DataType.cf32: "cutlass::complex", DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", + DataType.cu2: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", DataType.cu16: "cutlass::complex", DataType.cu32: "cutlass::complex", DataType.cu64: "cutlass::complex", + DataType.cs2: "cutlass::complex", DataType.cs4: "cutlass::complex", DataType.cs8: "cutlass::complex", DataType.cs16: "cutlass::complex", @@ -202,6 +211,7 @@ class DataType(enum.Enum): DataType.u16: 16, DataType.u32: 32, DataType.u64: 64, + DataType.s2: 2, DataType.s4: 4, DataType.s8: 8, DataType.s16: 16, @@ -219,11 +229,13 @@ class DataType(enum.Enum): DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128, + DataType.cu2: 4, DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32, DataType.cu32: 64, DataType.cu64: 128, + DataType.cs2: 4, DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32, diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 08fcd547d7..021406d700 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -492,6 +492,21 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if not (is_fp8 and is_sparse): schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue]) stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 0): + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + # Inconsistency: fp8 pingpong only gets stamped out with fast accum + if not is_fp8 or level >= 1: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpong, + EpilogueScheduleType.TmaWarpSpecialized + ]) + if can_do_fp8_fast_accum: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecialized + ]) if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue @@ -526,17 +541,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, # persistent kernels with TMA epilogues if can_do_tma_epilogue: assert not requires_transposed_epilogue - # Inconsistency: fp8 pingpong only gets stamped out with fast accum - if not is_fp8 or level >= 1: - schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpong, - EpilogueScheduleType.TmaWarpSpecialized - ]) - if can_do_fp8_fast_accum: - schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecialized - ]) if can_do_cooperative: # Sparse kernels only support FastAccum FP8 mainloop if not (is_fp8 and is_sparse): diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index 3f9996cfcf..da6c1dec6a 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -118,6 +118,5 @@ def evt_func_call2(accum, C, alpha, beta): result_keys = ["D"] launcher.verify((m, n, k), input_keys, result_keys, l) - if __name__ == '__main__': unittest.main() diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt index a90707468f..a576868b0a 100644 --- a/test/self_contained_includes/CMakeLists.txt +++ b/test/self_contained_includes/CMakeLists.txt @@ -131,6 +131,115 @@ set(header_files_to_check cute/atom/mma_traits_sm80.hpp cute/atom/mma_traits_sm90.hpp cute/atom/mma_traits_sm90_gmma.hpp + # cutlass + cutlass/aligned_buffer.h + cutlass/array.h + cutlass/array_planar_complex.h + cutlass/array_subbyte.h + cutlass/barrier.h + cutlass/bfloat16.h + cutlass/blas3.h + cutlass/blas3_types.h + cutlass/block_striped.h + cutlass/cluster_launch.hpp + cutlass/complex.h + cutlass/constants.h + cutlass/coord.h + cutlass/core_io.h + cutlass/cuda_host_adapter.hpp + cutlass/cutlass.h + cutlass/device_kernel.h + cutlass/fast_math.h + cutlass/float8.h + # cutlass/floating_point_nvrtc.h + cutlass/functional.h + cutlass/gemm_coord.h + cutlass/gemm_coord.hpp + cutlass/half.h + cutlass/integer_subbyte.h + cutlass/kernel_hardware_info.h + cutlass/kernel_hardware_info.hpp + cutlass/kernel_launch.h + cutlass/matrix.h + cutlass/matrix_coord.h + cutlass/matrix_shape.h + cutlass/numeric_conversion.h + cutlass/numeric_size.h + cutlass/numeric_types.h + cutlass/pitch_linear_coord.h + cutlass/predicate.h + cutlass/predicate_vector.h + cutlass/quaternion.h + cutlass/real.h + cutlass/relatively_equal.h + cutlass/semaphore.h + cutlass/subbyte_reference.h + cutlass/tensor_coord.h + cutlass/tensor_ref.h + cutlass/tensor_ref_planar_complex.h + cutlass/tensor_view.h + cutlass/tensor_view_planar_complex.h + cutlass/tfloat32.h + cutlass/trace.h + cutlass/uint128.h + cutlass/version.h + cutlass/wmma_array.h + cutlass/workspace.h + # cutlass/platform + cutlass/platform/platform.h + + # cutlass/pipeline + cutlass/pipeline/pipeline.hpp + cutlass/pipeline/sm90_pipeline.hpp + # cutlass/detail + cutlass/detail/cluster.hpp + cutlass/detail/collective.hpp + cutlass/detail/dependent_false.hpp + cutlass/detail/helper_macros.hpp + cutlass/detail/layout.hpp + cutlass/detail/mainloop_fusion_helper_bgrada.hpp + cutlass/detail/mma.hpp + # cutlass/arch + cutlass/arch/arch.h + cutlass/arch/barrier.h + cutlass/arch/cache_operation.h + cutlass/arch/config.h + cutlass/arch/custom_abi.h + cutlass/arch/grid_dependency_control.h + cutlass/arch/memory.h + # cutlass/arch/memory_sm75.h + # cutlass/arch/memory_sm80.h + cutlass/arch/mma.h + # cutlass/arch/mma_sm50.h + # cutlass/arch/mma_sm60.h + # cutlass/arch/mma_sm61.h + # cutlass/arch/mma_sm70.h + # cutlass/arch/mma_sm75.h + # cutlass/arch/mma_sm80.h + # cutlass/arch/mma_sm89.h + # cutlass/arch/mma_sm90.h + cutlass/arch/mma_sparse_sm80.h + cutlass/arch/mma_sparse_sm89.h + # cutlass/arch/simd.h + # cutlass/arch/simd_sm60.h + # cutlass/arch/simd_sm61.h + cutlass/arch/reg_reconfig.h + cutlass/arch/tma_operation.h + cutlass/arch/wmma.h + # cutlass/arch/wmma_sm70.h + # cutlass/arch/wmma_sm72.h + # cutlass/arch/wmma_sm75.h + # cutlass/arch/wmma_sm80.h + # cutlass/layout + cutlass/layout/layout.h + cutlass/layout/matrix.h + cutlass/layout/permute.h + cutlass/layout/pitch_linear.h + cutlass/layout/tensor.h + cutlass/layout/tensor_op_multiplicand_sm70.h + cutlass/layout/tensor_op_multiplicand_sm75.h + cutlass/layout/tensor_op_multiplicand_sm80.h + cutlass/layout/vector.h ) # for each header in _header_files: diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 7c4864f684..b02ec65a73 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -63,7 +63,7 @@ set(CUTLASS_TEST_UNIT_RESULTS_CACHE_DIR ${CMAKE_CURRENT_LIST_DIR}/data/hashes) function(cutlass_test_unit_add_executable NAME) - set(options WITHOUT_CUDA) + set(options WITHOUT_CUDA DO_NOT_LOWERCASE_TEST_NAME) set(oneValueArgs) set(multiValueArgs TEST_SETS_SUPPORTED EXTRA_INCLUDE_DIRS) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -109,14 +109,22 @@ function(cutlass_test_unit_add_executable NAME) set(CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS --gtest_output=xml:${NAME_STEM}.gtest.xml) + if (__DO_NOT_LOWERCASE_TEST_NAME) + set(DO_NOT_LOWERCASE_TEST_NAME DO_NOT_LOWERCASE_TEST_NAME) + else() + set(DO_NOT_LOWERCASE_TEST_NAME) + endif() + cutlass_add_executable_tests( ${NAME_STEM} ${NAME} TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED} TEST_COMMAND_OPTIONS CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS ${RESULT_CACHE_FILE_ARGS} + ${DO_NOT_LOWERCASE_TEST_NAME} ) endfunction() + add_custom_target(cutlass_test_unit) add_custom_target(test_unit) diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 5171eb5cf4..32acad1ec4 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -87,7 +87,6 @@ void FilterArchitecture() { << " [" << cudaGetErrorString(err) << "]" << std::endl; exit(1); } - cudaDeviceProp deviceProperties; err = cudaGetDeviceProperties(&deviceProperties, cudaDeviceId); if (cudaSuccess != err) { diff --git a/test/unit/conv/device_3x/conv_problem_sizes.hpp b/test/unit/conv/device_3x/conv_problem_sizes.hpp index ef651712d7..d66de64a04 100644 --- a/test/unit/conv/device_3x/conv_problem_sizes.hpp +++ b/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -1159,6 +1159,37 @@ std::vector> get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { using ProblemShape = cutlass::conv::ConvProblemShape; std::vector problem_shapes; + // Test TMA truncation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 512, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {2}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1024, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 2048, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {8}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); // non-packed input/output strides. // stride divides dilation // asymmetric padding diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index 3227f3d631..b392165c36 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -336,10 +336,17 @@ struct ConvTestbed { // Scale if constexpr (cute::is_same_v> || - cute::is_same_v>) { + cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> ) { fusion_args.activation.scale = ElementCompute{1}; } + // LeakyRelu + if constexpr (cute::is_same_v> ) { + fusion_args.activation.leaky_alpha = ElementCompute{0}; + } + cutlass::Status status = cutlass::Status::kInvalid; status = conv_op.can_implement(args); @@ -617,8 +624,9 @@ bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f for (DecompositionMode decomp_mode : decomposition_modes) { std::vector problem_splits = {Splits{1}}; if constexpr (UsesStreamKScheduler) { - if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + if (decomp_mode == DecompositionMode::SplitK) { problem_splits.push_back(Splits{2}); + problem_splits.push_back(Splits{4}); } } for (auto splits : problem_splits) { diff --git a/test/unit/core/float8.cu b/test/unit/core/float8.cu index 6fd044852d..14d9d22bc0 100644 --- a/test/unit/core/float8.cu +++ b/test/unit/core/float8.cu @@ -35,6 +35,7 @@ #include "../common/cutlass_unit_test.h" #include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" #include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/cute/ampere/CMakeLists.txt b/test/unit/cute/ampere/CMakeLists.txt index 6ac7f2f203..c1a654e893 100644 --- a/test/unit/cute/ampere/CMakeLists.txt +++ b/test/unit/cute/ampere/CMakeLists.txt @@ -28,7 +28,7 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_ampere - cp_async.cu + cp_sync.cu ldsm.cu cooperative_gemm.cu cooperative_copy.cu diff --git a/test/unit/cute/ampere/cooperative_copy.cu b/test/unit/cute/ampere/cooperative_copy.cu index a91000cb91..fef61aa238 100644 --- a/test/unit/cute/ampere/cooperative_copy.cu +++ b/test/unit/cute/ampere/cooperative_copy.cu @@ -46,6 +46,7 @@ #include // cute::Swizzle #include // cute::compose(cute::Swizzle) #include +#include using namespace cute; @@ -71,7 +72,7 @@ cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layo Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), gmem_layout); Tensor s_tensor = make_tensor(make_smem_ptr(smem), smem_layout); - cooperative_copy(threadIdx.x, g_in_tensor, s_tensor); + cooperative_copy(threadIdx.x, g_in_tensor, s_tensor, AutoCopyAsync{}); cp_async_fence(); cp_async_wait<0>(); @@ -84,7 +85,7 @@ cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layo } __syncthreads(); - cooperative_copy(threadIdx.x, s_tensor, g_out_tensor); + cooperative_copy(threadIdx.x, s_tensor, g_out_tensor, AutoCopyAsync{}); } // ss --> shared to shared @@ -106,7 +107,7 @@ cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Lay Tensor s1_tensor = make_tensor(make_smem_ptr(smem1), layout2); Tensor s2_tensor = make_tensor(make_smem_ptr(smem2), layout1); - cooperative_copy>(threadIdx.x, g_in_tensor, s1_tensor); + cooperative_copy>(threadIdx.x, g_in_tensor, s1_tensor, AutoCopyAsync{}); cp_async_fence(); cp_async_wait<0>(); @@ -119,10 +120,10 @@ cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Lay } __syncthreads(); - cooperative_copy(threadIdx.x, s1_tensor, s2_tensor); + cooperative_copy(threadIdx.x, s1_tensor, s2_tensor, AutoCopyAsync{}); __syncthreads(); - cooperative_copy>(threadIdx.x, s2_tensor, g_out_tensor); + cooperative_copy>(threadIdx.x, s2_tensor, g_out_tensor, AutoCopyAsync{}); } // gg --> global to global @@ -135,7 +136,7 @@ cooperative_copy_default_gg(T const* g_in, T* g_out, Layout1 const& layout1, Lay Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), layout1); Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), layout2); - cooperative_copy(threadIdx.x, g_in_tensor, g_out_tensor); + cooperative_copy(threadIdx.x, g_in_tensor, g_out_tensor, AutoCopyAsync{}); } template @@ -252,7 +253,7 @@ typedef testing::Types< std::tuple>, std::tuple>, std::tuple>, - std::tuple>, + std::tuple> > CooperativeCopyModeMaxVecBitsList; TYPED_TEST_SUITE(SM80_CuTe_Ampere, CooperativeCopyModeMaxVecBitsList); diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index 2ba01933e2..5bb6ecd2de 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -40,406 +40,462 @@ using namespace cute; TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - constexpr uint32_t thread_block_size = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA) { - using value_type = double; - - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - constexpr uint32_t thread_block_size = 128; + using value_type = double; - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm3_Half_MMA_CustomSmemLayouts) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 128; - constexpr uint32_t n = 128; - constexpr uint32_t k = 128; - constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = + auto shape_mnk = Shape<_128, _128, _128>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout>, // 2x2x1 thread group Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group` - >; - - using smem_a_atom_layout_t = Layout, Stride< _1,_64>>; - using smem_b_atom_layout_t = Layout, Stride<_32, _1>>; - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - - test_cooperative_gemm_col_major_layout{}; + + auto smem_a_atom_layout = Layout, Stride< _1,_64>>{}; + auto smem_b_atom_layout = Layout, Stride<_32, _1>>{}; + auto smem_c_atom_layout = make_layout(select<0,1>(shape_mnk)); + + test_cooperative_gemm_col_major_layout(); + value_type> + (smem_a_atom_layout, + smem_b_atom_layout, + smem_c_atom_layout, + shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm4_Half_MMA_SwizzledSmemLayouts) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 128; - constexpr uint32_t n = 128; - constexpr uint32_t k = 128; - constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = + auto shape_mnk = Shape<_128, _128, _128>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout>, // 2x2x1 thread group Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group` - >; + >{}; // RowMajor - using smem_rowmajor_atom_layout_t = decltype( + auto smem_a_atom_layout = composition(Swizzle<3,3,3>{}, Layout, - Stride<_64, _1>>{})); + Stride<_64, _1>>{}); // ColMajor - using smem_colmajor_atom_layout_t = decltype( + auto smem_b_atom_layout = composition(Swizzle<3,3,3>{}, Layout, - Stride< _1,_64>>{})); - using smem_a_atom_layout_t = smem_rowmajor_atom_layout_t; - using smem_b_atom_layout_t = smem_colmajor_atom_layout_t; - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); - - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - - using smem_a_atom_layout_t = smem_a_atom_layout_t; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); - - using smem_b_atom_layout_t = smem_b_atom_layout_t; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); - - using smem_c_atom_layout_t = smem_c_atom_layout_t; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); - - test_cooperative_gemm, // C - thread_block_size, - tiled_mma_t, - 128, + Stride< _1,_64>>{}); + + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); + + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{}); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); + + auto smem_a_layout = tile_to_shape( + smem_a_atom_layout, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_b_atom_layout, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_c_atom_layout, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + cute::identity{}, // TransformLoadA + cute::identity{}, // TransformLoadB + cute::identity{}, // TransformLoadC + cute::identity{}, // TransformStoreC + SM75_U32x4_LDSM_N{}, // A + SM75_U16x8_LDSM_T{}, // B + AutoVectorizingCopyWithAssumedAlignment<128>{}); // C } TEST(SM80_CuTe_Ampere, CooperativeGemm5_Double_MMA_SwizzledSmemLayouts) { - using value_type = double; - - constexpr uint32_t m = 128; - constexpr uint32_t n = 64; - constexpr uint32_t k = 16; - constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = double; - using tiled_mma_t = + auto shape_mnk = Shape<_128, _64, _16>{}; + auto tiled_mma = TiledMMA, // Atom Layout>, // Atom layout Tile, Stride<_2, _1>>, // 32x32x4 MMA with perm for load vectorization Layout, Stride<_2, _1>>, - Underscore>>; + Underscore>>{}; - using smem_a_atom_layout_t = decltype( + auto smem_a_atom_layout = composition(Swizzle<2,2,2>{}, Layout, - Stride< _1,_16>>{})); // M, K - using smem_b_atom_layout_t = decltype( + Stride< _1,_16>>{}); // M, K + auto smem_b_atom_layout = composition(Swizzle<2,2,2>{}, Layout, - Stride< _1,_16>>{})); // N, K - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); - - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - - using smem_a_atom_layout_t = smem_a_atom_layout_t; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); - using smem_b_atom_layout_t = smem_b_atom_layout_t; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); - using smem_c_atom_layout_t = smem_c_atom_layout_t; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<128>, // B - AutoVectorizingCopyWithAssumedAlignment<128>, // C - thread_block_size, - tiled_mma_t, - 128, + Stride< _1,_16>>{}); // N, K + + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); + + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{}); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); + + auto smem_a_layout = tile_to_shape( + smem_a_atom_layout, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + auto smem_b_layout = tile_to_shape( + smem_b_atom_layout, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + auto smem_c_layout = tile_to_shape( + smem_c_atom_layout, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm6_MixedPrecisionFP16FP32_MMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::half_t; using TB = cutlass::half_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm7_MixedPrecisionBF16FP32_MMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::bfloat16_t; using TB = cutlass::bfloat16_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_MMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::tfloat32_t; using TB = cutlass::tfloat32_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } -TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) { - +TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA_Dynamic) { + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; using TA = cutlass::complex; using TB = cutlass::complex; using TC = cutlass::complex; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _4, _0>>, + Tile + >{}; + + auto a_layout = make_layout(Shape,Int<35>>{}, make_stride(44, 1)); + auto b_layout = make_layout(Shape< Int<7>, Int<35>>{}, make_stride(44, 1)); + auto c_layout = make_layout(Shape, Int<7>>{}, make_stride(1, 30)); + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) { constexpr uint32_t thread_block_size = 256; constexpr int MaxVecBits = 128; + using TA = cutlass::complex; + using TB = cutlass::complex; + using TC = cutlass::complex; - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _4, _0>>, Tile - >; - - using ALayout = Layout,Int<35>>, Stride, Int<1> >>; - using BLayout = Layout, Int<35>>, Stride, Int<1> >>; - using CLayout = Layout, Int<7>>, Stride< Int<1>, Int<30>>>; - - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, - MaxVecBits, - TA, - TB, - TC>(); + >{}; + + auto a_layout = Layout,Int<35>>, Stride, Int<1> >>{}; + auto b_layout = Layout, Int<35>>, Stride, Int<1> >>{}; + auto c_layout = Layout, Int<7>>, Stride< Int<1>, Int<30>>>{}; + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm10_F16F64F16_FMA) { + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; using TA = cutlass::half_t; using TB = double; using TC = cutlass::half_t; - constexpr uint32_t thread_block_size = 256; - constexpr int MaxVecBits = 128; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom>, Layout, Stride<_1, _16, _0>>, Tile - >; - - using ALayout = Layout,Int<64>>, Stride, Int< 1>>>; - using BLayout = Layout,Int<64>>, Stride, Int<64>>>; - using CLayout = Layout,Int<64>>, Stride, Int<64>>>; - - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, + >{}; + + auto a_layout = Layout,Int<64>>, Stride, Int< 1>>>{}; + auto b_layout = Layout,Int<64>>, Stride, Int<64>>>{}; + auto c_layout = Layout,Int<64>>, Stride, Int<64>>>{}; + + test_cooperative_gemm(); + TC> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemmComposedStride) { - using T = cute::half_t; - constexpr uint32_t thread_block_size = 128; constexpr int MaxVecBits = 16; + using T = cute::half_t; - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _2, _0>>, Tile - >; + >{}; - using swizzle = cute::Swizzle<3, 3, 3>; - using offset = cute::_0; - using atom_tile_right = decltype(cute::make_layout(cute::Shape{}, cute::LayoutRight{})); - using FP16AtomLayoutRight = decltype(cute::composition(swizzle{}, offset{}, atom_tile_right{})); + auto swizzle = cute::Swizzle<3, 3, 3>{}; + auto offset = cute::_0{}; + auto atom_tile_right = cute::make_layout(cute::Shape{}, cute::LayoutRight{}); + auto FP16AtomLayoutRight = cute::composition(swizzle, offset, atom_tile_right); - using shape = cute::Shape, cute::Int<128>>; - using global_a_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{})); - using global_b_layout = decltype(cute::make_layout(shape{}, cute::LayoutLeft{})); - using global_c_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{})); + auto shape = cute::Shape, cute::Int<128>>{}; + auto global_a_layout = cute::make_layout(shape, cute::LayoutRight{}); + auto global_b_layout = cute::make_layout(shape, cute::LayoutLeft{}); + auto global_c_layout = cute::make_layout(shape, cute::LayoutRight{}); // This is for A row major, B col major according to CUTLASS default configs - using ALayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_a_layout{})); - using BLayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_b_layout{})); - using CLayout = global_c_layout; - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, + auto a_layout = cute::tile_to_shape(FP16AtomLayoutRight, global_a_layout); + auto b_layout = cute::tile_to_shape(FP16AtomLayoutRight, global_b_layout); + auto c_layout = global_c_layout; + + test_cooperative_gemm(); + T, T, T> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } -TEST(SM89_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) { +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) { + constexpr uint32_t thread_block_size = 64; + constexpr uint32_t max_vec_bits = 16; using TA = cutlass::tfloat32_t; using TB = cutlass::tfloat32_t; using TC = float; - constexpr uint32_t m = 9; - constexpr uint32_t n = 9; - constexpr uint32_t k = 9; + auto shape_mnk = Shape, C<9>, C<9>>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout + (shape_mnk, tiled_mma, cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{}); +} +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_TransformPrecision) { constexpr uint32_t thread_block_size = 64; + constexpr uint32_t max_vec_bits = 16; + using InputTA = cutlass::half_t; + using InputTB = cutlass::half_t; + using InputTC = cutlass::half_t; + + using ComputeTA = cutlass::tfloat32_t; + using ComputeTB = cutlass::tfloat32_t; + using ComputeTC = float; + + auto shape_mnk = Shape, C<9>, C<9>>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout + (shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_TransformPrecisionReg) { + constexpr uint32_t thread_block_size = 64; + constexpr uint32_t max_vec_bits = 16; + using InputTA = cutlass::half_t; + using InputTB = cutlass::half_t; + using InputTC = cutlass::half_t; + + using ComputeTA = cutlass::tfloat32_t; + using ComputeTB = cutlass::tfloat32_t; + using ComputeTC = float; - using tiled_mma_t = + auto shape_mnk = Shape, C<9>, C<9>>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c + (shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA_Reg) { + using value_type = cutlass::half_t; + + auto shape_mnk = Shape<_64, _64, _64>{}; + + constexpr uint32_t thread_block_size = 128; + + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Reg) { + constexpr uint32_t thread_block_size = 128; + using value_type = double; + + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Predicated_Reg) { + constexpr uint32_t thread_block_size = 128; + using value_type = double; + + auto shape_mnk = Shape, C<62>, C<62>>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; - test_cooperative_gemm_col_major_layout(cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{}); + test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); } diff --git a/test/unit/cute/ampere/cp_async.cu b/test/unit/cute/ampere/cp_sync.cu similarity index 97% rename from test/unit/cute/ampere/cp_async.cu rename to test/unit/cute/ampere/cp_sync.cu index 871e8ff98b..f50454107f 100644 --- a/test/unit/cute/ampere/cp_async.cu +++ b/test/unit/cute/ampere/cp_sync.cu @@ -69,14 +69,12 @@ test2(double const* g_in, double* g_out) copy(g_tensor, s_tensor); - cp_async_fence(); - cp_async_wait<0>(); __syncthreads(); g_out[threadIdx.x] = 2 * smem[threadIdx.x]; } -TEST(SM80_CuTe_Ampere, CpAsync) +TEST(SM80_CuTe_Ampere, CpSync) { constexpr int count = 32; thrust::host_vector h_in(count); diff --git a/test/unit/cute/cooperative_gemm_common.hpp b/test/unit/cute/cooperative_gemm_common.hpp index 5dec22ca0c..dbb85e6ba3 100644 --- a/test/unit/cute/cooperative_gemm_common.hpp +++ b/test/unit/cute/cooperative_gemm_common.hpp @@ -54,57 +54,181 @@ struct fp64_tester> { using value_type = complex; }; -template // logical shape (M, N) +auto host_generate_gemm_inputs( + ALayout a_layout, + BLayout b_layout, + CLayout c_layout +) { + thrust::host_vector h_a(cosize(a_layout)); + thrust::host_vector h_b(cosize(b_layout)); + thrust::host_vector h_c(cosize(c_layout)); + thrust::host_vector h_c_out(cosize(c_layout)); + + auto h_a_tensor = make_tensor(h_a.data(), a_layout); + auto h_b_tensor = make_tensor(h_b.data(), b_layout); + auto h_c_tensor = make_tensor(h_c.data(), c_layout); + size_t max_size = std::max({static_cast(size(a_layout)), + static_cast(size(b_layout)), + static_cast(size(c_layout))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(a_layout)) { + h_a_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(b_layout)) { + h_b_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(c_layout)) { + h_c_tensor(i) = static_cast((di*di) / size(a_layout)); + } + } + + return std::make_tuple(h_a, h_b, h_c, h_c_out); +} + +template +thrust::host_vector +host_reference_gemm(Alpha alpha, + Tensor const& h_a_tensor, + Tensor const& h_b_tensor, + Beta beta, + Tensor const& h_c_tensor, + ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) + { + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TA = remove_cv_t; + using TB = remove_cv_t; + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + return h_c_ref; +} + +template +void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, + cute::Tensor const& h_c_ref_tensor) +{ + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + for (int i = 0; i < size(h_c_ref_tensor); i++) { + ABC_64 h_c_ref_i = h_c_ref_tensor(i); + ABC_64 h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + + +template + class CStoreTransform, + class SMemCopyOpA, + class SMemCopyOpB, + class SMemCopyOpC> __launch_bounds__(ThreadBlockSize) __global__ void -cooperative_gemm_kernel(TA const* a, - TB const* b, - TC* c, - TC* c_out, - Alpha const alpha, - Beta const beta, +cooperative_gemm_kernel(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + Alpha const alpha, + Beta const beta, + TiledMma tiled_mma, ALoadTransform a_load_transform, BLoadTransform b_load_transform, CLoadTransform c_load_transform, - CStoreTransform c_store_transform) + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op, + SMemCopyOpC c_copy_op) { using namespace cute; - Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), ALayout{}); - Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), BLayout{}); - Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), CLayout{}); - Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), CLayout{}); + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; extern __shared__ float4 smem_buf[]; auto* smem_ptr = reinterpret_cast(smem_buf); auto* smem_ptr_a = smem_ptr; - auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(SMemALayout {})), copy_max_vec_bytes); - auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(SMemBLayout {})), copy_max_vec_bytes); + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); - Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), SMemALayout{}); - Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), SMemBLayout{}); - Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), SMemCLayout{}); + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); @@ -114,98 +238,176 @@ cooperative_gemm_kernel(TA const* a, cp_async_wait<0>(); __syncthreads(); - TiledMma tiled_mma; - cooperative_gemm( + cooperative_gemm( threadIdx.x, tiled_mma, alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, - a_load_transform, b_load_transform, c_load_transform, c_store_transform + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_copy_op, b_copy_op, c_copy_op ); __syncthreads(); cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); } -template -void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) -{ - using gmem_a_layout_t = ALayout; - using gmem_b_layout_t = BLayout; - using gmem_c_layout_t = CLayout; + class TiledMma, + class ALoadTransform, + class BLoadTransform, + class CLoadTransform, + class CStoreTransform, + class SMemCopyOpA, + class SMemCopyOpB> +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op) + { + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Create C fragment for storing intermediate results + auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); + Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); + Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); + Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); + + // Create indexing help for predicated GEMMs + Tensor cC = make_identity_tensor(shape(gmem_c_layout)); + Tensor tCcC = thr_mma.partition_C(cC); + + // Load C from global + // (always loading in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + r_c_partition(i) = c_load_transform(g_c_partition(i)); + } + } - using smem_a_layout_t = SMemALayout; - using smem_b_layout_t = SMemBLayout; - using smem_c_layout_t = SMemCLayout; + cooperative_gemm( + threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, + a_load_transform, b_load_transform, a_copy_op, b_copy_op + ); + __syncthreads(); + + // Store C to global + // (always storing in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + g_c_out_partition(i) = c_store_transform(r_c_partition(i)); + } + } +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}, + CSMemCopyOp c_smem_copy_op = {}) +{ static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - using tester = fp64_tester; - using ABC_64 = typename tester::value_type; - static_assert(size<0>(gmem_a_layout_t{}) == size<0>(gmem_c_layout_t{})); // AM == CM - static_assert(size<0>(gmem_b_layout_t{}) == size<1>(gmem_c_layout_t{})); // BN == CN - static_assert(size<1>(gmem_a_layout_t{}) == size<1>(gmem_b_layout_t{})); // AK == BK + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK - static_assert(size<0>(smem_a_layout_t{}) == size<0>(smem_c_layout_t{})); // AM == CM - static_assert(size<0>(smem_b_layout_t{}) == size<1>(smem_c_layout_t{})); // BN == CN - static_assert(size<1>(smem_a_layout_t{}) == size<1>(smem_b_layout_t{})); // AK == BK + static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM + static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK - static_assert(cute::size(gmem_a_layout_t {}) == cute::size(smem_a_layout_t {})); - static_assert(cute::size(gmem_b_layout_t {}) == cute::size(smem_b_layout_t {})); - static_assert(cute::size(gmem_c_layout_t {}) == cute::size(smem_c_layout_t {})); + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); #if 0 - print(" "); print("gmem: "); print(gmem_layout_t{}); print("\n"); - print(" "); print("smem: "); print(smem_layout_t{}); print("\n"); + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); #endif const auto alpha = static_cast(1.1); const auto beta = static_cast(1.2); - thrust::host_vector h_a(cosize(gmem_a_layout_t{})); - thrust::host_vector h_b(cosize(gmem_b_layout_t{})); - thrust::host_vector h_c(cosize(gmem_c_layout_t{})); - thrust::host_vector h_c_out(cosize(gmem_c_layout_t{})); - - auto h_a_tensor = make_tensor(h_a.data(), gmem_a_layout_t{}); - auto h_b_tensor = make_tensor(h_b.data(), gmem_b_layout_t{}); - auto h_c_tensor = make_tensor(h_c.data(), gmem_c_layout_t{}); - size_t max_size = std::max({static_cast(size(gmem_a_layout_t {})), - static_cast(size(gmem_b_layout_t {})), - static_cast(size(gmem_c_layout_t {}))}); - for (size_t i = 0; i < max_size; ++i) { - double di = static_cast(i); - if(i < size(gmem_a_layout_t{})) { - h_a_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); - } - if(i < size(gmem_b_layout_t{})) { - h_b_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); - } - if(i < size(gmem_c_layout_t{})) { - h_c_tensor(i) = static_cast((di*di) / size(gmem_a_layout_t{})); - } - } + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); thrust::device_vector d_a(h_a); thrust::device_vector d_b(h_b); @@ -213,220 +415,356 @@ void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) - + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) - + (sizeof(TC) * h_c.size()); + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + sizeof(TC) * h_c.size(); + + auto kernel = cooperative_gemm_kernel< - gmem_a_layout_t, gmem_b_layout_t, gmem_c_layout_t, - smem_a_layout_t, smem_b_layout_t, smem_c_layout_t, - SmemCopyOpA, SmemCopyOpB, SmemCopyOpC, - ThreadBlockSize, TiledMma, CopyMaxVecBits, + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, SMemCLayout, TA, TB, TC, decltype(alpha), decltype(beta), - ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp, CSMemCopyOp >; + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, thrust::raw_pointer_cast(d_a.data()), thrust::raw_pointer_cast(d_b.data()), thrust::raw_pointer_cast(d_c.data()), thrust::raw_pointer_cast(d_c_out.data()), alpha, beta, + tiled_mma, a_load_transform, b_load_transform, c_load_transform, - c_store_transform + c_store_transform, + a_smem_copy_op, + b_smem_copy_op, + c_smem_copy_op ); + cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { cudaError_t error = cudaGetLastError(); FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; } - thrust::host_vector h_c_ref(h_c.size(), static_cast(0.0)); - auto h_c_ref_tensor = make_tensor(h_c_ref.data(), gmem_c_layout_t{}); - // A * B - for (int k = 0; k < size<1>(h_a_tensor); k++) { - for (int m = 0; m < size<0>(h_a_tensor); m++) { - for (int n = 0; n < size<0>(h_b_tensor); n++) { - const auto a_value = a_load_transform(h_a_tensor(m, k)); - const auto b_value = b_load_transform(h_b_tensor(n, k)); - const auto a_value_fp64 = static_cast(a_value); - const auto b_value_fp64 = static_cast(b_value); - h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); - } - } - } - // C = A*B + C - for (int i = 0; i < size(h_c_ref_tensor); i++) { - const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); - const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); - h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Copy result data + h_c_out = d_c_out; + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}) +{ + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.0); + const auto beta = static_cast(1.0); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = + host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); + + + auto kernel = cooperative_gemm_kernel_rmem_c< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, + TA, TB, TC, + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + tiled_mma, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_smem_copy_op, b_smem_copy_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; } + // Copy result data h_c_out = d_c_out; - auto h_c_out_tensor = make_tensor(h_c_out.data(), gmem_c_layout_t{}); - for (int i = 0; i < size(h_c_ref_tensor); i++) { - ABC_64 h_c_ref_i = h_c_ref_tensor(i); - ABC_64 h_c_out_i = h_c_out_tensor(i); - double epsilon(0.1f); - double nonzero_floor(std::numeric_limits::min()); - bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); - ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; - } + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); } -template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) + class ShapeMNK, + class TiledMma, + class ... Ops> +void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) { - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - - using smem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - using smem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using smem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - - test_cooperative_gemm>, - AutoVectorizingCopyWithAssumedAlignment>, - AutoVectorizingCopyWithAssumedAlignment>, - ThreadBlockSize, - TiledMMAType, + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm(a_load_transform, b_load_transform, c_load_transform, c_store_transform); + TA, TB, TC> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma, + ops...); } -template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) -{ - test_cooperative_gemm_col_major_layout, T, T, T>( - a_load_transform, b_load_transform, c_load_transform, c_store_transform); -} -template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) + class SMemAtomLayoutA, + class SMemAtomLayoutB, + class SMemAtomLayoutC, + class ShapeMNK, + class TiledMma, + class ... Ops> +std::enable_if_t, + cute::is_layout, + cute::is_layout>> +test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + SMemAtomLayoutC smem_atom_layout_c, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops&& ... ops) { - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - - using smem_a_atom_layout_t = SMemAAtomLayout; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); - using smem_b_atom_layout_t = SMemBAtomLayout; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); - using smem_c_atom_layout_t = SMemCAtomLayout; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_atom_layout_c, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); - test_cooperative_gemm>, - AutoVectorizingCopyWithAssumedAlignment>, - AutoVectorizingCopyWithAssumedAlignment>, - ThreadBlockSize, - TiledMMAType, + test_cooperative_gemm(a_load_transform, b_load_transform, c_load_transform, c_store_transform); + TA, TB, TC> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + ops...); +} + + +template +void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + + test_cooperative_gemm_rmem_c + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + tiled_mma, + ops...); } -template +std::enable_if_t, + cute::is_layout>> +test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + test_cooperative_gemm_rmem_c + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + tiled_mma, + ops...); +} + +template +void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout_rmem_c, + T, T, T> + (static_cast(args)...); +} + +template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) + class ... Args> +void test_cooperative_gemm_col_major_layout(Args&& ... args) { - test_cooperative_gemm_col_major_layout, - T, - T, - T>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); + T, T, T> + (static_cast(args)...); } diff --git a/test/unit/cute/core/inverse_left.cpp b/test/unit/cute/core/inverse_left.cpp index 363cb17f98..142d80fb6a 100644 --- a/test/unit/cute/core/inverse_left.cpp +++ b/test/unit/cute/core/inverse_left.cpp @@ -104,7 +104,7 @@ TEST(CuTe_core, Inverse_left) auto layout = Layout, Stride<_4, _1>>{}; - test_left_inverse(filter(layout)); + test_left_inverse(layout); } { diff --git a/test/unit/cute/hopper/cooperative_gemm.cu b/test/unit/cute/hopper/cooperative_gemm.cu index c4e2274dfe..7d992510d6 100644 --- a/test/unit/cute/hopper/cooperative_gemm.cu +++ b/test/unit/cute/hopper/cooperative_gemm.cu @@ -44,91 +44,74 @@ using namespace cute; #if USE_FP8 TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF8) { + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 16; using TA = uint8_t; using TB = uint8_t; using TC = uint32_t; - constexpr uint32_t thread_block_size = 128; - constexpr int MaxVecBits = 16; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _2, _0>>, Tile<_32, _32, _32> - >; + >{}; - using swizzle = Swizzle<2, 4, 3>; + auto swizzle = Swizzle<2, 4, 3>{}; // This is for A row major, B col major according to CUTLASS default configs - using ALayout = decltype(composition(swizzle{}, Layout, Stride<_64, _1>>{})); - using BLayout = decltype(composition(swizzle{}, Layout, Stride<_1, _64>>{})); - - using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{})); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, - MaxVecBits, - TA, - TB, - TC>(); + auto a_layout = composition(swizzle, Layout, Stride<_64, _1>>{}); + auto b_layout = composition(swizzle, Layout, Stride<_1, _64>>{}); + auto c_layout = make_layout(Shape<_64, _64>{}, LayoutLeft{}); + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } #else TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF16) { + constexpr uint32_t thread_block_size = 64; + constexpr int max_vec_bits = 16; using TA = half_t; using TB = half_t; using TC = half_t; - constexpr uint32_t thread_block_size = 64; - constexpr int MaxVecBits = 16; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _0, _0>>, Tile<_32, _32, _32> - >; - - using swizzle = Swizzle<3, 3, 3>; + >{}; // This is for A row major, B col major according to CUTLASS default configs - using ALayout = decltype(composition(swizzle{}, - Layout, Stride<_64, _1>>{})); - - using BLayout = decltype(composition(swizzle{}, - Layout, Stride<_1, _64>>{})); - - using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{})); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, - MaxVecBits, + auto swizzle = Swizzle<3, 3, 3>{}; + auto ALayout = composition(swizzle{}, Layout, Stride<_64, _1>>{}); + auto BLayout = composition(swizzle{}, Layout, Stride<_1, _64>>{}); + auto CLayout = make_layout(Shape<_64, _64>{}, LayoutLeft{}); + + test_cooperative_gemm(); - + TC> + + (ALayout, + BLayout, + CLayout, + ALayout, + BLayout, + CLayout, + tiled_mma); } #endif diff --git a/test/unit/cute/hopper/tma_load_testbed.hpp b/test/unit/cute/hopper/tma_load_testbed.hpp index 0c8ed91d69..58d19e4aa1 100644 --- a/test/unit/cute/hopper/tma_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_load_testbed.hpp @@ -131,7 +131,7 @@ tma_test_device_cute(T const* g_in, T* g_out, for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); if (threadIdx.x == 0) { diff --git a/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp index 2fb88de50d..bca378793f 100644 --- a/test/unit/cute/hopper/tma_mcast_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -146,7 +146,7 @@ tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); if (elect_one_thr) { diff --git a/test/unit/cute/turing/cooperative_gemm.cu b/test/unit/cute/turing/cooperative_gemm.cu index 14ea967074..1bda5cf77f 100644 --- a/test/unit/cute/turing/cooperative_gemm.cu +++ b/test/unit/cute/turing/cooperative_gemm.cu @@ -38,21 +38,19 @@ using namespace cute; TEST(SM75_CuTe_Turing, CooperativeGemm1_MixedPrecisionFP16FP32_MMA) { + + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::half_t; using TB = cutlass::half_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = make_shape(_64{}, _64{}, _64{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } diff --git a/test/unit/cute/volta/cooperative_gemm.cu b/test/unit/cute/volta/cooperative_gemm.cu index 157913f9e5..54cf4f2214 100644 --- a/test/unit/cute/volta/cooperative_gemm.cu +++ b/test/unit/cute/volta/cooperative_gemm.cu @@ -40,105 +40,85 @@ using namespace cute; TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA) { - using value_type = float; - - constexpr uint32_t m = 64; - constexpr uint32_t n = 32; - constexpr uint32_t k = 16; constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(_64{}, _32{}, _16{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication) { - using value_type = float; - - constexpr uint32_t m = 88; - constexpr uint32_t n = 20; - constexpr uint32_t k = 12; constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(C<88>{}, C<20>{}, C<12>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication2) { - using value_type = float; - - constexpr uint32_t m = 88; - constexpr uint32_t n = 36; - constexpr uint32_t k = 24; constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(C<88>{}, C<36>{}, C<24>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication3) { - using value_type = float; - - constexpr uint32_t m = 67; - constexpr uint32_t n = 13; - constexpr uint32_t k = 11; - constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(C<67>{}, C<13>{}, C<11>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm2_DoubleFMA) { - using value_type = double; - - constexpr uint32_t m = 16; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - constexpr uint32_t thread_block_size = 128; + using value_type = double; - using tiled_mma_t = + auto shape_mnk = make_shape(C<16>{}, C<32>{}, C<32>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) { - using value_type = float; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; constexpr uint32_t thread_block_size = 256; + using value_type = float; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom< UniversalFMA >, @@ -154,228 +134,188 @@ TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) { >, Underscore > - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm4_Half_MMA) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - constexpr uint32_t thread_block_size = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; - - using smem_a_atom_layout_t = typename tiled_mma_t::AtomLayoutB_TV; - using smem_b_atom_layout_t = typename tiled_mma_t::AtomLayoutA_TV; - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - - test_cooperative_gemm_col_major_layout(); + >{}; + + auto smem_a_atom_layout = typename decltype(tiled_mma)::AtomLayoutB_TV{}; + auto smem_b_atom_layout = typename decltype(tiled_mma)::AtomLayoutA_TV{}; + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm_col_major_layout + (smem_a_atom_layout, + smem_b_atom_layout, + smem_c_atom_layout, + shape_mnk, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; - - using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - - using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<128>, // B - AutoVectorizingCopyWithAssumedAlignment<128>, // C - thread_block_size, - tiled_mma_t, - 128, + >{}; + + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto smem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto smem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA_Predicated) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 31; - constexpr uint32_t n = 27; - constexpr uint32_t k = 17; constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 16; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(C<31>{}, C<27>{}, C<17>{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; - - using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - - using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<16>, // B - AutoVectorizingCopyWithAssumedAlignment<16>, // C - thread_block_size, - tiled_mma_t, - 16, + >{}; + + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto smem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto smem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm6_Half_MAA_SwizzledSmemLayouts) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 128; - constexpr uint32_t n = 128; - constexpr uint32_t k = 64; constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_128{}, _128{}, _64{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; - - using smem_a_atom_layout_t = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride<_64, _1>>{})); - using smem_b_atom_layout_t = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride< _1,_64>>{})); - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); - - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - - using smem_a_atom_layout_t = smem_a_atom_layout_t; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); - - // Transposed - using smem_b_atom_layout_t = smem_b_atom_layout_t; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); - - using smem_c_atom_layout_t = smem_c_atom_layout_t; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<128>, // B - AutoVectorizingCopyWithAssumedAlignment<128>, // C - thread_block_size, - tiled_mma_t, - 128, + >{}; + + auto smem_a_atom_layout = composition(Swizzle<3,3,3>{}, Layout, Stride<_64, _1>>{}); + auto smem_b_atom_layout = composition(Swizzle<3,3,3>{}, Layout, Stride< _1,_64>>{}); + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); + + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{}); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); + + auto smem_a_layout = tile_to_shape( + smem_a_atom_layout, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_b_atom_layout, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_c_atom_layout, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_FMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 64; using TA = float; using TB = float; using TC = double; - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; auto aload = cute::negate {}; auto bload = cute::negate {}; auto cload = cute::negate {}; auto cstore = cute::negate {}; - test_cooperative_gemm_col_major_layout( - aload, bload, cload, cstore); + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma, aload, bload, cload, cstore); } TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_MMA) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; constexpr uint32_t thread_block_size = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; auto aload = cute::negate {}; auto bload = cute::negate {}; auto cload = cute::negate {}; auto cstore = cute::negate {}; - test_cooperative_gemm_col_major_layout( - aload, bload, cload, cstore); + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma, aload, bload, cload, cstore); } template @@ -398,26 +338,25 @@ struct convert_to { }; TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformCustomOp_FMA) { + + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 64; + using TA = float; using TB = float; using TC = double; - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; auto aload = increment_by_x{1.111f}; auto bload = convert_to {}; auto cload = cute::negate {}; auto cstore = cute::negate {}; - test_cooperative_gemm_col_major_layout( - aload, bload, cload, cstore); + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma, aload, bload, cload, cstore); } diff --git a/test/unit/cute/volta/vectorization_auto.cu b/test/unit/cute/volta/vectorization_auto.cu index b378f8b329..585abf0e26 100644 --- a/test/unit/cute/volta/vectorization_auto.cu +++ b/test/unit/cute/volta/vectorization_auto.cu @@ -67,7 +67,6 @@ kernel(GmemTensor gC, RmemTiler tiler, CopyPolicy policy) // NOTE: only 1 thread, this thread produce a block of 8x8 output. The fringe will not be touched. //copy(rC, tCgC); // Enable auto-vectorization if static - //copy_vec(rC, tCgC); // Disable auto-vectorization always copy(policy, rC, tCgC); // Use a policy to establish vectorization assumptions } diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 488c6bfa6f..87b6e53d50 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -26,52 +26,30 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -add_custom_target( - cutlass_test_unit_gemm_device - DEPENDS - cutlass_test_unit_gemm_device_simt - cutlass_test_unit_gemm_device_tensorop_sm70 - cutlass_test_unit_gemm_device_tensorop_sm75 - cutlass_test_unit_gemm_device_tensorop_f16_sm80 - cutlass_test_unit_gemm_device_tensorop_f32_sm80 - cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80 - cutlass_test_unit_gemm_device_tensorop_f64 - cutlass_test_unit_gemm_device_tensorop_s32_sm80 - cutlass_test_unit_gemm_device_wmma - cutlass_test_unit_gemm_device_tensorop_planar_complex - cutlass_test_unit_gemm_device_sparse_tensorop_sm80 - cutlass_test_unit_gemv_device - cutlass_test_unit_gemm_device_tensorop_sm90 - cutlass_test_unit_sparse_gemm_device_tensorop_sm90 - cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 -) +add_custom_target(cutlass_test_unit_gemm_device) +add_custom_target(test_unit_gemm_device) -add_custom_target( - test_unit_gemm_device - DEPENDS - test_unit_gemm_device_simt - test_unit_gemm_device_tensorop_sm70 - test_unit_gemm_device_tensorop_sm75 - test_unit_gemm_device_tensorop_f16_sm80 - test_unit_gemm_device_tensorop_f32_sm80 - test_unit_gemm_device_tensorop_f32_tf32_sm80 - test_unit_gemm_device_tensorop_f64 - test_unit_gemm_device_tensorop_s32_sm80 - test_unit_gemm_device_wmma - test_unit_gemm_device_tensorop_planar_complex - test_unit_gemm_device_sparse_tensorop_sm80 - test_unit_gemv_device - test_unit_gemm_device_tensorop_sm90 -) +################################################################################ -add_custom_target( - cutlass_test_unit_gemm_device_sm90 - DEPENDS - cutlass_test_unit_gemm_device_tensorop_sm90 - cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 -) +function(cutlass_test_unit_gemm_device_add_deps NAME) + string(REGEX REPLACE "^cutlass_" "" TEST_NAME "${NAME}") + add_dependencies(cutlass_test_unit_gemm_device ${NAME}) + add_dependencies(test_unit_gemm_device ${TEST_NAME}) +endfunction() + +function(cutlass_test_unit_gemm_device_add_executable NAME) + cutlass_test_unit_add_executable(${NAME} ${ARGN} DO_NOT_LOWERCASE_TEST_NAME) + cutlass_test_unit_gemm_device_add_deps(${NAME}) +endfunction() + +function(cutlass_test_unit_gemm_device_add_executable_split_file NAME) + cutlass_test_unit_add_executable_split_file(${NAME} ${ARGN} DO_NOT_LOWERCASE_TEST_NAME) + cutlass_test_unit_gemm_device_add_deps(${NAME}) +endfunction() -cutlass_test_unit_add_executable( +################################################################################ + +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_simt BATCH_SOURCES ON @@ -126,7 +104,9 @@ cutlass_test_unit_add_executable( gemm_splitk_simt_sm50.cu ) -cutlass_test_unit_add_executable( +list(APPEND CUTLASS_TEST_UNIT_GEMM_DEVICE_LIST cutlass_test_unit_gemm_device_simt) + +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_simt_3x BATCH_SOURCES ON @@ -139,8 +119,7 @@ cutlass_test_unit_add_executable( sm61_gemm_s8_s8_s32_simt.cu ) - -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm70 BATCH_SOURCES ON @@ -159,7 +138,7 @@ cutlass_test_unit_add_executable( gemm_splitk_tensor_op_sm70.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm75 BATCH_SOURCES ON @@ -204,7 +183,7 @@ cutlass_test_unit_add_executable( ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f16_sm80 BATCH_SOURCES ON @@ -214,7 +193,7 @@ cutlass_test_unit_add_executable( gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_sm80 BATCH_SOURCES ON @@ -236,7 +215,7 @@ cutlass_test_unit_add_executable( gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_sm80_3x sm80_gemm_s8_s8_s32_tensor_op.cu @@ -245,7 +224,7 @@ cutlass_test_unit_add_executable( ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 BATCH_SOURCES ON @@ -286,13 +265,14 @@ cutlass_test_unit_add_executable( gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90 BATCH_SOURCES ON BATCH_SIZE 4 sm90_gemm_f16_f16_f16_tensor_op.cu + sm90_gett_f16_f16_f16_tensor_op.cu sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu sm90_gemm_s8_s8_s8_tensor_op_s32.cu sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -302,7 +282,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f8_f8_f8_tensor_op_fp32.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_stream_k sm90_gemm_stream_k_scheduler.cu @@ -311,7 +291,7 @@ cutlass_test_unit_add_executable( ) # Alignment tests -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_alignx_sm90 BATCH_SOURCES ON @@ -336,14 +316,14 @@ cutlass_test_unit_add_executable( ) # Ptr Array test -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu ) # Group Gemm test -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu @@ -351,31 +331,25 @@ cutlass_test_unit_add_executable( # Sparse tests # Sparse kernels trigger an ICE in gcc 7.5 -if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) -cutlass_test_unit_add_executable( - cutlass_test_unit_sparse_gemm_device_tensorop_sm90 - - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 - - sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu - sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu - sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu - sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu -) -else() -cutlass_test_unit_add_executable( - cutlass_test_unit_sparse_gemm_device_tensorop_sm90 +if (NOT (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) + + cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_sparse_gemm_device_tensorop_sm90 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu + sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu + sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu + sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu + ) - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 -) endif() # Fused epilogue tests -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_sm90 BATCH_SOURCES ON @@ -400,7 +374,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative_evt.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 BATCH_SOURCES ON @@ -412,7 +386,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90 BATCH_SOURCES ON @@ -423,7 +397,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80 BATCH_SOURCES ON @@ -443,7 +417,7 @@ cutlass_test_unit_add_executable( sm80_gemm_f16_f16_f32_tensor_op_f32.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f64 BATCH_SOURCES ON @@ -471,7 +445,7 @@ cutlass_test_unit_add_executable( gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_s32_sm80 BATCH_SOURCES ON @@ -493,7 +467,7 @@ cutlass_test_unit_add_executable( gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_wmma BATCH_SOURCES ON @@ -551,7 +525,7 @@ cutlass_test_unit_add_executable( gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_planar_complex BATCH_SOURCES ON @@ -562,7 +536,7 @@ cutlass_test_unit_add_executable( gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm89 BATCH_SOURCES ON @@ -574,7 +548,7 @@ cutlass_test_unit_add_executable( # gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped BATCH_SOURCES ON @@ -583,7 +557,7 @@ cutlass_test_unit_add_executable( gemm_grouped_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped_scheduler BATCH_SOURCES ON @@ -592,7 +566,7 @@ cutlass_test_unit_add_executable( gemm_grouped_scheduler_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped_rank_2k_scheduler BATCH_SOURCES ON @@ -601,7 +575,7 @@ cutlass_test_unit_add_executable( rank_2k_grouped_scheduler_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_sparse_tensorop_sm80 BATCH_SOURCES ON @@ -622,7 +596,7 @@ cutlass_test_unit_add_executable( gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemv_device BATCH_SOURCES ON @@ -631,42 +605,22 @@ cutlass_test_unit_add_executable( gemv.cu ) -if (NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUTLASS_NVCC_DEVICE_COMPILE) -add_dependencies( - cutlass_test_unit_gemm_device - cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + + gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu + gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu + + gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu ) -add_dependencies( - test_unit_gemm_device - test_unit_gemm_device_gemm_with_fused_epilogue_tensorop - ) - -cutlass_test_unit_add_executable( - cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop - - gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu - gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu - - gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu -) - endif() -if (NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUTLASS_NVCC_DEVICE_COMPILE) -add_dependencies( - cutlass_test_unit_gemm_device - cutlass_test_unit_gemm_device_blas3 - ) - -add_dependencies( - test_unit_gemm_device - test_unit_gemm_device_blas3 - ) - -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_blas3 BATCH_SOURCES ON @@ -833,7 +787,7 @@ cutlass_test_unit_add_executable( hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped_blas3 BATCH_SOURCES ON @@ -858,13 +812,12 @@ cutlass_test_unit_add_executable( endif() -if (NOT CUDA_COMPILER MATCHES "[Cc]lang") - -cutlass_test_unit_add_executable( - cutlass_test_unit_gemm_device_broadcast +if (CUTLASS_NVCC_DEVICE_COMPILE) - gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu -) + cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_broadcast + gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu + ) endif() diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 3d5a586dd2..3a6cf0b2eb 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -39,6 +39,7 @@ #include #include #include +#include // std::lcm #include "../../common/cutlass_unit_test.h" #include "cutlass/util/host_tensor.h" @@ -55,6 +56,7 @@ #include "cutlass/complex.h" #include "cutlass/transform/device/transform_universal_adapter.hpp" #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/detail/collective.hpp" #include "testbed_utils.h" @@ -151,6 +153,12 @@ struct ElementScalarType +struct IsSfdEpi : cute::false_type {}; + +template +struct IsSfdEpi> : cute::true_type {}; + // The maximum swizzle size to use // // This class, like Splits above makes it harder to confuse @@ -1140,7 +1148,6 @@ struct HostCollectiveEpilogue { static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && (cute::is_same_v || cute::is_same_v); - using Arguments = typename Gemm::GemmKernel::EpilogueArguments; /// Initialization @@ -1454,6 +1461,22 @@ struct HostCollectiveEpilogue { bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); if(!passed) { + #if 0 + auto [M, N, K, L] = problem_shape_MNKL; + auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { + printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); + } + } + } + } + #endif std::cout<<"D is incorrect"<) { + fusion_args.beta = beta.at(coord_0); + fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr + } if constexpr (IsPerRowScaleEnabled) { int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; @@ -1620,6 +1646,7 @@ struct HostCollectiveEpilogue { // example of how to set kernel activation arguments // see ActivationFunctor::Arguments in activation.h for definition // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { fusion_args.activation.scale = ElementCompute(1); } @@ -1713,6 +1740,7 @@ struct HostCollectiveEpilogue { decltype(Vbeta), ActivationFunctor, cutlass::plus + , false /*PerColumnBias_*/ > epilogue_params{}; epilogue_params.C = C; @@ -1779,6 +1807,7 @@ struct TestbedImpl { using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, HostCollectiveDefaultEpilogue, HostCollectiveEpilogue>; @@ -2004,7 +2033,7 @@ struct TestbedImpl { return false; } } - catch (std::exception const& e) { + catch ([[maybe_unused]] std::exception const& e) { CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); throw; } diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index b7d1c57923..479102b32c 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -346,7 +346,7 @@ struct HostCollectiveMainloop { stride_b_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); for(int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -380,7 +380,7 @@ struct HostCollectiveMainloop { Arguments to_args(ProblemShapeType problem_shapes) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); std::vector ptr_A_host(L); std::vector ptr_B_host(L); @@ -587,7 +587,7 @@ struct HostCollectiveDefaultEpilogue { stride_d_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -649,7 +649,7 @@ struct HostCollectiveDefaultEpilogue { ElementScalar beta, int batch) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); tensors_D[batch].sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); @@ -678,7 +678,7 @@ struct HostCollectiveDefaultEpilogue { Arguments to_args(ProblemShapeType problem_shapes) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); std::vector ptr_C_host(L); std::vector ptr_D_host(L); @@ -724,8 +724,8 @@ struct HostCollectiveDefaultEpilogue { // // Allocate the GEMM workspace // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + L = std::max(problem_shapes.groups(), L); auto coord_0 = cutlass::make_Coord(0); auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), @@ -905,9 +905,8 @@ struct HostCollectiveEpilogue { references_D.clear(); stride_c_host.clear(); stride_d_host.clear(); - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = std::max(problem_shapes.groups(), L); for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -1118,7 +1117,6 @@ struct HostCollectiveEpilogue { passed &= tmp; } } - return passed; } @@ -1189,7 +1187,7 @@ struct HostCollectiveEpilogue { Arguments to_args(ProblemShapeType problem_shapes) { auto coord_0 = cutlass::make_Coord(0); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = std::max(problem_shapes.groups(), L); std::vector ptr_C_host(L); std::vector ptr_D_host(L); @@ -1220,19 +1218,22 @@ struct HostCollectiveEpilogue { device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); } + auto device_tensors_C_ptr = cute::is_void_v ? nullptr : + reinterpret_cast(device_tensors_C.get()); + Arguments arguments; if constexpr (IsGroupGemm) { arguments = { {}, - device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() }; } else { arguments = { {}, - device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] }; } @@ -1252,7 +1253,9 @@ struct HostCollectiveEpilogue { fusion_args.beta = beta.at(coord_0); fusion_args.alpha_ptr = alpha.device_data(); - fusion_args.beta_ptr = beta.device_data(); + // can_implement requires beta_ptr to not be set if its voidC + fusion_args.beta_ptr = cute::is_void_v ? nullptr : + beta.device_data(); if constexpr (IsScaleFactorEnabled) { fusion_args.scale_a = scale_A.at(coord_0); @@ -1316,7 +1319,8 @@ struct HostCollectiveEpilogue { // // Allocate the GEMM workspace // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto [M, N, K, L] = problem_shape_MNKL; auto coord_0 = cutlass::make_Coord(0); auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); @@ -1338,7 +1342,6 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); - cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, @@ -1518,7 +1521,7 @@ struct TestbedImpl { { using namespace cute; auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = std::max(problem_shapes.groups(), L); bool passed = true; for (int32_t i = 0; i < L; ++i) { @@ -1760,7 +1763,7 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative cutlass::DeviceAllocation problem_sizes_device; for (int i = 0; i < batch; ++i) { - problem_sizes_host.push_back({m, n, k}); + problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); } problem_sizes_device.reset(problem_sizes_host.size()); diff --git a/test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu b/test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu new file mode 100644 index 0000000000..4d03fc939e --- /dev/null +++ b/test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu @@ -0,0 +1,184 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/reference/device/gett.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gett_f16t_f16n_f16n_tensor_op_gmma_f16, 8x8x8x8x8x8) { + + using BatModeStrides = int; + + using RowModeStridesA = cute::Stride; + using RedModeStrides = cute::Stride; + + using ColModeStridesB = cute::Stride; + + using RowModeStridesC = cute::Stride; + using ColModeStridesC = cute::Stride; + + using StrideA = cute::Stride; + using StrideB = cute::Stride; + using StrideC = cute::Stride; + using StrideD = StrideC; + + using TileShape = Shape, Shape<_8, _8>, Shape<_8, _8>>; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, StrideA, 8, + cutlass::half_t, StrideB, 8, + cutlass::half_t, + TileShape, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, StrideC, 8, + cutlass::half_t, StrideC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GettKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Shape, + Shape, + int>, + CollectiveOp, + CollectiveEpilogue + >; + + using Gett = cutlass::gemm::device::GemmUniversalAdapter; + + auto problem_shape = make_shape( + make_shape(32,8), + make_shape(32,4), + make_shape(32,2), + 1 + ); + + auto [M, N, K, L] = problem_shape; + + StrideA dA = make_stride(make_stride(64, 2048), make_stride(_1{}, 32), size(M) * size(K)); + StrideB dB = make_stride(make_stride(64, 2048), make_stride(_1{}, 32), size(N) * size(K)); + StrideC dC = make_stride(make_stride(_1{}, 32), make_stride(256, 8192), size(M) * size(N)); + StrideD dD = dC; + + cutlass::half_t alpha = cutlass::half_t(1.0f); + cutlass::half_t beta = cutlass::half_t(1.0f); + + thrust::host_vector A_h(size(M) * size(K) * size(L)); + thrust::host_vector B_h(size(N) * size(K) * size(L)); + thrust::host_vector C_h(size(M) * size(N) * size(L)); + thrust::host_vector D_h(size(M) * size(N) * size(L)); + thrust::host_vector D_h_ref(size(M) * size(N) * size(L)); + + for (auto& a : A_h) a = cutlass::half_t(static_cast(4 * (rand() / double(RAND_MAX) - 1))); + for (auto& b : B_h) b = cutlass::half_t(static_cast(4 * (rand() / double(RAND_MAX) - 1))); + for (auto& c : C_h) c = cutlass::half_t(static_cast(4 * (rand() / double(RAND_MAX) - 1))); + for (auto& d : D_h) d = cutlass::half_t(-1); + for (auto& d : D_h_ref) d = cutlass::half_t(-1); + + thrust::device_vector A = A_h; + thrust::device_vector B = B_h; + thrust::device_vector C = C_h; + thrust::device_vector D = D_h; + thrust::device_vector D_ref = D_h_ref; + + typename Gett::Arguments args { + cutlass::gemm::GemmUniversalMode::kBatched, + problem_shape, + {A.data().get(), dA, B.data().get(), dB}, + { {alpha, beta}, C.data().get(), dC, D.data().get(), dD} + }; + + Gett gett; + auto status = gett(args); + EXPECT_TRUE(status == cutlass::Status::kSuccess); + auto cuda_err = cudaDeviceSynchronize(); + + EXPECT_TRUE(cuda_err == cudaSuccess); + + cutlass::reference::device::gett( + problem_shape, + A.data().get(), dA, + B.data().get(), dB, + cutlass::half_t(0.0f), + C.data().get(), dC, + D_ref.data().get(), dD, + alpha, beta); + + cuda_err = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_err == cudaSuccess); + + bool passed = cutlass::reference::device::BlockCompareEqual( + D.data().get(), D_ref.data().get(), D_ref.size()); + EXPECT_TRUE(passed); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/transform/CMakeLists.txt b/test/unit/transform/CMakeLists.txt index 4912eca2c3..0ab0b93f50 100644 --- a/test/unit/transform/CMakeLists.txt +++ b/test/unit/transform/CMakeLists.txt @@ -33,7 +33,7 @@ add_custom_target( cutlass_test_unit_transform DEPENDS cutlass_test_unit_transform_threadblock - cutlass_test_unit_transform_filter_format + cutlass_test_unit_transform_kernel ) add_custom_target( diff --git a/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp index 3a42d74ffc..8ec0c4ac46 100644 --- a/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp +++ b/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp @@ -45,6 +45,7 @@ #include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor #include "cutlass/arch/arch.h" // cutlass::arch::Sm90 #include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t #include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up #include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo @@ -219,9 +220,7 @@ class SM90StructuredSparseCompressorLegacy { // * EltA using ElementA = ElementA_; using ElementAUint = cute::uint_bit_t>; - static constexpr bool IsRuntimeDataTypeA = cute::is_same_v || - cute::is_same_v || - cute::is_same_v; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); using ArrayElementA = cute::conditional_t>, ElementA>; diff --git a/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp index af50348e8c..03e4fa75b1 100644 --- a/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp +++ b/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp @@ -60,6 +60,7 @@ #include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride #include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals #include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill +#include "cutlass/detail/collective.hpp" #include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor #include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE diff --git a/test/unit/transform/kernel/CMakeLists.txt b/test/unit/transform/kernel/CMakeLists.txt index d337b31ed9..92d4a47bdb 100644 --- a/test/unit/transform/kernel/CMakeLists.txt +++ b/test/unit/transform/kernel/CMakeLists.txt @@ -27,6 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. cutlass_test_unit_add_executable( - cutlass_test_unit_transform_filter_format + cutlass_test_unit_transform_kernel filter_format_transformer.cu ) diff --git a/test/unit/util/rms_norm.cu b/test/unit/util/rms_norm.cu index c366f7e088..4111406714 100644 --- a/test/unit/util/rms_norm.cu +++ b/test/unit/util/rms_norm.cu @@ -104,7 +104,7 @@ void run_test(int M, int N) { for (int n = 0; n < N; ++n) { auto diff = abs(static_cast(output_ref.at({m, n}) - output.at({m, n}))); mean_abs_diff += diff; - max_abs_diff = max(max_abs_diff, diff); + max_abs_diff = cutlass::platform::max(max_abs_diff, diff); } } diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h index f55b5131e8..d87d0895b8 100644 --- a/tools/library/include/cutlass/library/handle.h +++ b/tools/library/include/cutlass/library/handle.h @@ -72,6 +72,8 @@ class Handle { /// Pointer to the most recently executed operation Operation const *last_operation_; + int device_idx_; + public: /// Constructor diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 56e6e455de..a3af54ba26 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -118,7 +118,8 @@ class Operation { void const *arguments, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const = 0; + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const = 0; }; @@ -272,6 +273,8 @@ struct GemmUniversalConfiguration { int64_t ldb{0}; int64_t ldc{0}; int64_t ldd{0}; + + int device_count{1}; }; struct GemmUniversalArguments { @@ -303,6 +306,8 @@ struct GemmUniversalArguments { int sm_count{0}; library::RasterOrder raster_order{}; int swizzle_size{1}; + + int device_index{0}; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/conv2d_operation.h b/tools/library/src/conv2d_operation.h index cf29c889dc..027b2615f1 100644 --- a/tools/library/src/conv2d_operation.h +++ b/tools/library/src/conv2d_operation.h @@ -326,7 +326,12 @@ class Conv2dOperation : public Conv2dOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -578,7 +583,12 @@ class DirectConv2dOperation : public Conv2dOperation { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/conv3d_operation.h b/tools/library/src/conv3d_operation.h index 758866b8f7..6cb1796b5a 100644 --- a/tools/library/src/conv3d_operation.h +++ b/tools/library/src/conv3d_operation.h @@ -317,7 +317,12 @@ class Conv3dOperation : public Conv3dOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/conv_operation_3x.hpp b/tools/library/src/conv_operation_3x.hpp index 0093182201..d6f79e9196 100644 --- a/tools/library/src/conv_operation_3x.hpp +++ b/tools/library/src/conv_operation_3x.hpp @@ -236,12 +236,14 @@ class ConvOperation3x : public Operation { typename Operator::Arguments out_args{}; status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); return status; } auto* in_args_ptr = reinterpret_cast(arguments); status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); return status; } @@ -332,7 +334,8 @@ class ConvOperation3x : public Operation { void const* arguments, void* host_workspace, void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const override + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const override { auto status = Status::kInvalid; @@ -358,7 +361,7 @@ class ConvOperation3x : public Operation { } auto* op = reinterpret_cast(host_workspace); - return op->run(out_args, device_workspace, stream); + return op->run(out_args, device_workspace, stream, nullptr, launch_with_pdl); } private: @@ -482,6 +485,11 @@ class ConvOperation3x : public Operation { typename Operator::Arguments& out_args, Conv2dConfiguration const& config) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv2dConfiguration)\n"); +#endif using detail::vector_to_array_strides; constexpr int num_spatial_dims = Operator::NumSpatialDimensions; @@ -595,6 +603,7 @@ class ConvOperation3x : public Operation { const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); + const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); // cutlass::library::Conv2dConfiguration has no member stride_d. // The code below imitates the testbed, @@ -605,12 +614,57 @@ class ConvOperation3x : public Operation { CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); return Status::kInvalid; } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // This means that stride_act isn't always config.stride_A, + // depending on Fprop / Dgrad / Wgrad. The code here "undoes" + // the logic in Conv2dWorkspace::set_stride_vector so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + problem_shape_type problem_shape( /* mode = */ mode, - /* shape_act = */ {N, H, W, C}, - /* stride_act = */ stride_A, - /* shape_flt = */ {K, R, S, C}, - /* stride_flt = */ stride_B, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, /* lower_padding = */ {pad_h, pad_w}, /* upper_padding = */ {pad_h, pad_w}, /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, @@ -620,9 +674,11 @@ class ConvOperation3x : public Operation { // ConvProblemShape's constructor sets its shape_C member. #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " problem_shape:\n" - << " shape_C: " << problem_shape.shape_C << "\n"; - std::cerr << " stride_C: " << problem_shape.stride_C << "\n"; + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); #endif // Initialization of C's and D's strides follows the CUTLASS 3 // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). @@ -670,6 +726,11 @@ class ConvOperation3x : public Operation { typename Operator::Arguments& out_args, Conv3dConfiguration const& config) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv3dConfiguration)\n"); +#endif using detail::coord_to_array_strides; constexpr int num_spatial_dims = Operator::NumSpatialDimensions; @@ -762,6 +823,10 @@ class ConvOperation3x : public Operation { print_stride(input_stride_b, "input_stride_b"); print_stride(input_stride_c, "input_stride_c"); #endif + // Conv3dConfiguration stores the strides as Coord (with + // compile-time size), so there's no need to check sizes here + // (unlike Conv2dConfiguration, which stores strides as + // std::vector). constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; using problem_shape_type = @@ -771,18 +836,68 @@ class ConvOperation3x : public Operation { const TensorStride stride_A = coord_to_array_strides(input_stride_a); const TensorStride stride_B = coord_to_array_strides(input_stride_b); + const TensorStride stride_C = coord_to_array_strides(input_stride_c); const int num_groups = config.problem_size.groups; if (num_groups != 1) { CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); return Status::kInvalid; } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // Conv3dConfiguration differs a bit from Conv2dConfiguration, + // but the idea is the same: the "input_stride_a" from config + // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act + // isn't always input_stride_a. Analogously, stride_flt isn't + // always input_stride_b. The code here "undoes" the logic in + // config.layout_a(conv_kind) and config.layout_b(conv_kind) + // (analogous to Conv2dWorkspace::set_stride_vector) so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, D, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, T, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + problem_shape_type problem_shape( /* mode = */ mode, - /* shape_act = */ {N, D, H, W, C}, - /* stride_act = */ stride_A, - /* shape_flt = */ {K, T, R, S, C}, - /* stride_flt = */ stride_B, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, /* lower_padding = */ {pad_d, pad_h, pad_w}, /* upper_padding = */ {pad_d, pad_h, pad_w}, /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, @@ -792,15 +907,15 @@ class ConvOperation3x : public Operation { // ConvProblemShape's constructor sets its shape_C member. #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " problem_shape:\n" - << " shape_C: " << problem_shape.shape_C << "\n"; - std::cerr << " stride_C: " << problem_shape.stride_C << "\n"; + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); #endif - + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " Compute stride_C and stride_D\n"; -#endif using StrideC = typename Operator::ConvKernel::StrideC; using StrideD = typename Operator::ConvKernel::StrideD; auto stride_C = StrideC{}; @@ -845,9 +960,8 @@ class ConvOperation3x : public Operation { ConvArguments const& in_args) const { #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << "ConvOperation3x::update_operator_arguments_from_arguments\n"; + CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); #endif - auto status = UpdateFusionArgs::update_( out_args.epilogue.thread, in_args); if (status != Status::kSuccess) { diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 2c2c5c9c4e..5c6f9ca815 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -296,7 +296,12 @@ class GemmOperation : public GemmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -500,7 +505,12 @@ class GemmSparseOperation : public GemmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -721,7 +731,12 @@ class GemmUniversalOperation : public GemmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -930,7 +945,12 @@ class GemmPlanarComplexOperation : public GemmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -1133,7 +1153,12 @@ class GemmPlanarComplexArrayOperation : public GemmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -1337,7 +1362,12 @@ class GemmGroupedOperation : public GemmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index f4918b7d37..7c87b45e0f 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" #include "cutlass/library/library.h" #include "library_internal.h" #include "cutlass/gemm/dispatch_policy.hpp" @@ -331,7 +332,8 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const override { OperatorArguments args; Status status = update_arguments_(args, static_cast(arguments_ptr)); @@ -341,7 +343,7 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { Operator *op = static_cast(host_workspace); // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(args, device_workspace, stream); + status = op->run(args, device_workspace, stream, nullptr, launch_with_pdl); return status; } }; diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index cfb176beeb..e6f00f7225 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -57,14 +57,12 @@ Handle::Handle( scalar_pointer_mode_(ScalarPointerMode::kHost), last_operation_(nullptr) { - int device_idx = -1; - - cudaError_t error = cudaGetDevice(&device_idx); + cudaError_t error = cudaGetDevice(&device_idx_); if (error != cudaSuccess) { throw std::runtime_error("cudaGetDevice() failed"); } - error = cudaGetDeviceProperties(&device_, device_idx); + error = cudaGetDeviceProperties(&device_, device_idx_); if (error != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } @@ -78,8 +76,14 @@ Handle::Handle( Handle::~Handle() { if (workspace_) { - if (workspace_) { - cudaFree(workspace_); + int device_before; + cudaGetDevice(&device_before); + if (device_before != device_idx_) { + cudaSetDevice(device_idx_); + } + cudaFree(workspace_); + if (device_before != device_idx_) { + cudaSetDevice(device_before); } workspace_ = nullptr; @@ -89,6 +93,10 @@ Handle::~Handle() { /// Move constructor Handle::Handle(Handle && handle) { + cudaError_t error = cudaGetDevice(&device_idx_); + if (error != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } device_ = handle.device_; workspace_size_ = handle.workspace_size_; workspace_ = handle.workspace_; @@ -112,6 +120,8 @@ Handle & Handle::operator=(Handle && handle) { handle.workspace_ = nullptr; handle.workspace_size_ = 0; + device_idx_ = handle.device_idx_; + return *this; } @@ -151,6 +161,12 @@ void *Handle::get_workspace() const { /// Sets the size of device workspace, invalidating previous calls to get_device_workspace() void Handle::set_workspace_size(size_t bytes) { + int device_before; + cudaGetDevice(&device_before); + if (device_before != device_idx_) { + cudaSetDevice(device_idx_); + } + if (bytes != workspace_size_) { if (workspace_) { @@ -177,6 +193,9 @@ void Handle::set_workspace_size(size_t bytes) { throw std::runtime_error("Failed to clear workspace"); } } + if (device_before != device_idx_) { + cudaSetDevice(device_before); + } } /// Gets the scalar pointer mode diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 2b57dbc317..be311c6255 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -72,6 +72,10 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kB1; }; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS2; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kS4; }; @@ -92,6 +96,10 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kS64; }; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU2; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kU4; }; diff --git a/tools/library/src/rank_2k_operation.h b/tools/library/src/rank_2k_operation.h index b353a34727..5a6111041b 100644 --- a/tools/library/src/rank_2k_operation.h +++ b/tools/library/src/rank_2k_operation.h @@ -314,7 +314,12 @@ class Rank2KOperation : public Rank2KOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/rank_k_operation.h b/tools/library/src/rank_k_operation.h index e5e5ec78db..e6afb1da6d 100644 --- a/tools/library/src/rank_k_operation.h +++ b/tools/library/src/rank_k_operation.h @@ -310,7 +310,12 @@ class RankKOperation : public RankKOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/reduction/reduction_operation.h b/tools/library/src/reduction/reduction_operation.h index ceb4627891..3bcabf091c 100644 --- a/tools/library/src/reduction/reduction_operation.h +++ b/tools/library/src/reduction/reduction_operation.h @@ -231,7 +231,12 @@ class ReductionOperation : public Operation { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h index ab924b5f01..2bafc4af62 100644 --- a/tools/library/src/reference/conv_reference_operation.h +++ b/tools/library/src/reference/conv_reference_operation.h @@ -432,7 +432,12 @@ class ConvReferenceOperation : public Operation { void const *arguments, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } ConvArguments const &args = *static_cast(arguments); diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h index fd58d4f0ac..940ff5217d 100644 --- a/tools/library/src/reference/gemm_reference_operation.h +++ b/tools/library/src/reference/gemm_reference_operation.h @@ -192,7 +192,12 @@ class GemmReferenceOperation : public Operation { void const *arguments, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } GemmUniversalConfiguration const &config = *static_cast(host_workspace); GemmUniversalArguments const &args = *static_cast(arguments); diff --git a/tools/library/src/reference/gemm_s8_s8_s32.cu b/tools/library/src/reference/gemm_s8_s8_s32.cu index d88e986f56..8c661b98a0 100644 --- a/tools/library/src/reference/gemm_s8_s8_s32.cu +++ b/tools/library/src/reference/gemm_s8_s8_s32.cu @@ -77,7 +77,7 @@ void initialize_gemm_reference_operations_s8_s8_s32(Manifest &manifest) { int8_t, // ElementA int8_t, // ElementB int32_t, // ElementC - int32_t, // ElementScalar / ElementCompute + float, // ElementScalar / ElementCompute int32_t, // ElementAccumulator int32_t // ElementD >(manifest); diff --git a/tools/library/src/sparse_gemm_operation_3x.hpp b/tools/library/src/sparse_gemm_operation_3x.hpp index fec987f5a2..8bfc41d726 100644 --- a/tools/library/src/sparse_gemm_operation_3x.hpp +++ b/tools/library/src/sparse_gemm_operation_3x.hpp @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" #include "cutlass/library/library.h" #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor #include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter @@ -56,7 +57,7 @@ namespace cutlass::library { /////////////////////////////////////////////////////////////////////////////////////////////////// -// Limitation & Assumptions: +// Limitation & Assumptions: // 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, // and lda is m if the tensor is m-major. // 2. Circular buffer for tensorA and tensorE may have a less count compared to tensorB and others. @@ -169,7 +170,6 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { return status; } - // TODO: type erase Arguments structure in 3.0 GEMM operator_args.problem_shape = cute::make_shape( arguments->problem_size.m(), arguments->problem_size.n(), @@ -302,13 +302,15 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { } Status initialize_with_profiler_workspace( - void const *configuration, - void *host_workspace, - void *device_workspace, + void const *configuration, + void *host_workspace, + void *device_workspace, uint8_t **profiler_workspaces, int problem_count_from_profiler, cudaStream_t stream = nullptr) { + iter_idx.resize(static_cast(configuration)->device_count, 0); + // Set problem_count. problem_count = problem_count_from_profiler; @@ -319,13 +321,10 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { // * Construct Op Operator *op = new (host_op_workspace_ptr) Operator; - // * Device Full Ptr - device_full_ptr = reinterpret_cast(device_workspace); - // * Device Ptr (1st iteration) // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | // iteri : op_workspace | tensor_ac | tensor_e - auto* device_ptr_iter1 = device_full_ptr; + auto* device_ptr_iter1 = static_cast(device_workspace); auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; @@ -335,15 +334,15 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { auto* device_a_raw_ptr = profiler_workspaces[0]; // * Random fill 50% of TensorA w/ zero following the structured sparse requirement - cudaMemcpy(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); - cudaMemcpy(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); CUDA_CHECK(cudaGetLastError()); // * Compress DTensorA and get DTensorAC & DTensorE cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); typename Compressor::Arguments arguments{ {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, @@ -372,23 +371,23 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { return status; } - CUDA_CHECK(cudaStreamSynchronize(stream)); - // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE for (int iter_i = 1; iter_i < problem_count; iter_i++) { // * Device AC E Ptr per iteration // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | // iteri : op_workspace | tensor_ac | tensor_e - auto* device_ptr_iteri = device_full_ptr + device_per_iter_workspace_size * iter_i; + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; auto* device_op_workspace_ptr = device_ptr_iteri; auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; - cudaMemcpy(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice); - cudaMemcpy(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice); + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); } + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaGetLastError()); return Status::kSuccess; @@ -398,17 +397,21 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { Status run( void const *arguments_ptr, void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { + void *device_workspace, + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const override { OperatorArguments operator_args; - auto* device_ptr_iteri = device_full_ptr + device_per_iter_workspace_size * iter_idx; + + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; auto* device_op_workspace_ptr = device_ptr_iteri; auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; - iter_idx = (iter_idx + 1) % problem_count; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); @@ -418,7 +421,7 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { Operator *op = static_cast(host_workspace); // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(operator_args, device_op_workspace_ptr, stream); + status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, launch_with_pdl); return status; } @@ -426,9 +429,7 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { // Variables that must change in the const functions. mutable CompressorUtility compressor_utility; mutable int problem_count = 1; - mutable int iter_idx = 0; - - uint8_t* device_full_ptr = nullptr; + mutable std::vector iter_idx; mutable uint64_t tensor_ac_size = 0; mutable uint64_t tensor_e_size = 0; diff --git a/tools/library/src/symm_operation.h b/tools/library/src/symm_operation.h index 548356b4b6..aeb06caf54 100644 --- a/tools/library/src/symm_operation.h +++ b/tools/library/src/symm_operation.h @@ -312,7 +312,12 @@ class SymmOperation : public SymmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/trmm_operation.h b/tools/library/src/trmm_operation.h index 80e4ad14e0..88c4f7ab7d 100644 --- a/tools/library/src/trmm_operation.h +++ b/tools/library/src/trmm_operation.h @@ -304,7 +304,12 @@ class TrmmOperation : public TrmmOperationBase { void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 1a7fb12812..d71caf4183 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -100,7 +100,7 @@ install( if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 AND (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED))) set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) else() - set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) endif() set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true) diff --git a/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h index 249750138c..32d79211c4 100644 --- a/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h @@ -430,7 +430,7 @@ class Conv2dOperationProfiler : public OperationProfiler { protected: /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h b/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h index 3cbf310606..2ce0a1c21b 100644 --- a/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h @@ -384,7 +384,7 @@ class Conv3dOperationProfiler : public OperationProfiler { /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/tools/profiler/include/cutlass/profiler/cublas_helpers.h index 32875245fd..10642e5ff0 100644 --- a/tools/profiler/include/cutlass/profiler/cublas_helpers.h +++ b/tools/profiler/include/cutlass/profiler/cublas_helpers.h @@ -304,7 +304,7 @@ struct cublasLtGemmExDispatcher { ); /// Executes GEMM using these arguments - cublasStatus_t operator()(cublasLtHandle_t handle); + cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); ~cublasLtGemmExDispatcher(){ diff --git a/tools/profiler/include/cutlass/profiler/enumerated_types.h b/tools/profiler/include/cutlass/profiler/enumerated_types.h index 25a42296b2..3e6efa4897 100644 --- a/tools/profiler/include/cutlass/profiler/enumerated_types.h +++ b/tools/profiler/include/cutlass/profiler/enumerated_types.h @@ -90,9 +90,9 @@ AlgorithmMode from_string(std::string const &str); /// Outcome of a performance test enum class Disposition { kPassed, - kFailed, + kFailed, // kernel itself reported an error kNotRun, - kIncorrect, + kIncorrect, // kernel finished without a detected error, but result does not equal expected result kNotVerified, kInvalidProblem, kNotSupported, diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 8e1292f986..b103e3db74 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -143,6 +143,8 @@ class GemmOperationProfiler : public OperationProfiler { /// Buffer used for the cutlass reduction operations' host workspace std::vector reduction_host_workspace; + + cudaStream_t stream; }; protected: @@ -155,7 +157,7 @@ class GemmOperationProfiler : public OperationProfiler { GemmProblem problem_; /// Device memory allocations - GemmWorkspace gemm_workspace_; + std::vector gemm_workspace_; /// CUTLASS parallel reduction operation to follow this* gemm operation library::Operation const *reduction_op_; @@ -231,7 +233,8 @@ class GemmOperationProfiler : public OperationProfiler { DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace); /// Verifies CUTLASS against host and device references bool verify_with_reference_( @@ -246,7 +249,7 @@ class GemmOperationProfiler : public OperationProfiler { /// Method to profile a CUTLASS Operation Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/gpu_timer.h b/tools/profiler/include/cutlass/profiler/gpu_timer.h index 304a362a19..815b6af172 100644 --- a/tools/profiler/include/cutlass/profiler/gpu_timer.h +++ b/tools/profiler/include/cutlass/profiler/gpu_timer.h @@ -51,16 +51,21 @@ struct GpuTimer { // GpuTimer(); + + GpuTimer(GpuTimer const&) = delete; + + GpuTimer(GpuTimer &&gpu_timer) noexcept; + ~GpuTimer(); - /// Records a start event in the stream - void start(cudaStream_t stream = nullptr); + /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags + void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - /// Records a stop event in the stream - void stop(cudaStream_t stream = nullptr); + /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags + void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - /// Records a stop event in the stream and synchronizes on the stream - void stop_and_wait(cudaStream_t stream = nullptr); + /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags + void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); /// Returns the duration in milliseconds double duration(int iterations = 1) const; diff --git a/tools/profiler/include/cutlass/profiler/operation_profiler.h b/tools/profiler/include/cutlass/profiler/operation_profiler.h index 3dfe3fcf89..7e3005fe78 100644 --- a/tools/profiler/include/cutlass/profiler/operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -232,7 +232,7 @@ class OperationProfiler { /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/performance_result.h b/tools/profiler/include/cutlass/profiler/performance_result.h index fb3393c45f..4b9a3321b2 100644 --- a/tools/profiler/include/cutlass/profiler/performance_result.h +++ b/tools/profiler/include/cutlass/profiler/performance_result.h @@ -86,6 +86,9 @@ struct PerformanceResult { /// Average runtime in ms double runtime; + /// Average runtime in ms per device + std::vector runtime_vector; + // // Members // diff --git a/tools/profiler/src/conv2d_operation_profiler.cu b/tools/profiler/src/conv2d_operation_profiler.cu index f74ffbe728..9589c0caa2 100644 --- a/tools/profiler/src/conv2d_operation_profiler.cu +++ b/tools/profiler/src/conv2d_operation_profiler.cu @@ -396,6 +396,29 @@ Status Conv2dOperationProfiler::initialize_configuration( problem_, operation_desc.conv_kind, operation_desc.A.layout, operation_desc.B.layout, operation_desc.C.layout); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + { + auto print_vector = [] (const auto& vec) { + printf("["); + for (size_t k = 0; k < vec.size(); ++k) { + cute::print(vec[k]); + if (k + 1 < vec.size()) { + printf(","); + } + } + printf("]"); + }; + + printf("\n conv_workspace_.configuration.stride_a: "); + print_vector(conv_workspace_.configuration.stride_a); + printf("\n conv_workspace_.configuration.stride_b: "); + print_vector(conv_workspace_.configuration.stride_b); + printf("\n conv_workspace_.configuration.stride_c: "); + print_vector(conv_workspace_.configuration.stride_c); + printf("\n"); + } +#endif + // initialize library::ConvArguments conv_workspace_.arguments.A = nullptr; conv_workspace_.arguments.B = nullptr; @@ -1237,7 +1260,7 @@ bool Conv2dOperationProfiler::profile( } results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &conv_workspace_.arguments, @@ -1251,7 +1274,7 @@ bool Conv2dOperationProfiler::profile( /// Method to profile a CUTLASS Operation Status Conv2dOperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, @@ -1387,7 +1410,7 @@ Status Conv2dOperationProfiler::profile_cutlass_( // Update performance result // - runtime = timer.duration(iteration); + result.runtime = timer.duration(iteration); return status; } diff --git a/tools/profiler/src/conv3d_operation_profiler.cu b/tools/profiler/src/conv3d_operation_profiler.cu index 8e8f5873d5..04d338c3ee 100644 --- a/tools/profiler/src/conv3d_operation_profiler.cu +++ b/tools/profiler/src/conv3d_operation_profiler.cu @@ -1099,7 +1099,7 @@ bool Conv3dOperationProfiler::profile( set_cutlass_operator_arguments_(); results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &conv_workspace_.arguments, @@ -1141,7 +1141,7 @@ void Conv3dOperationProfiler::set_cutlass_operator_arguments_(int problem_idx) { /// Method to profile a CUTLASS Operation Status Conv3dOperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, @@ -1248,7 +1248,7 @@ Status Conv3dOperationProfiler::profile_cutlass_( // Update performance result // - runtime = timer.duration(iteration); + result.runtime = timer.duration(iteration); return status; } diff --git a/tools/profiler/src/cublas_helpers.cu b/tools/profiler/src/cublas_helpers.cu index 7467c1db24..412b0a2461 100644 --- a/tools/profiler/src/cublas_helpers.cu +++ b/tools/profiler/src/cublas_helpers.cu @@ -656,7 +656,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle, return true; } -cublasStatus_t cublasLtGemmExDispatcher::operator()(cublasLtHandle_t handle) +cublasStatus_t cublasLtGemmExDispatcher::operator()(cublasLtHandle_t handle, cudaStream_t stream) { return cublasLtMatmul(handle, operationDesc, @@ -673,7 +673,7 @@ cublasStatus_t cublasLtGemmExDispatcher::operator()(cublasLtHandle_t handle) &heuristicResult_.algo, workspace, heuristicResult_.workspaceSize, - 0); //number of streams is set to 0 + stream); //number of streams is set to 0 } diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 4e57244e95..a1866b55fc 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -290,9 +290,8 @@ DeviceAllocation::DeviceAllocation(): capacity_(0), pointer_(nullptr), layout_(library::LayoutTypeID::kUnknown), - batch_count_(1), - device_(-1) { - + batch_count_(1) { + cudaGetDevice(&device_); } DeviceAllocation::DeviceAllocation( @@ -329,13 +328,33 @@ DeviceAllocation::DeviceAllocation( DeviceAllocation::~DeviceAllocation() { if (pointer_) { + int current_device; + cudaGetDevice(¤t_device); + + if (current_device != device_) { + cudaSetDevice(device_); + } cudaFree(pointer_); + + if (current_device != device_) { + cudaSetDevice(current_device); + } } } DeviceAllocation &DeviceAllocation::reset() { if (pointer_) { + int current_device; + cudaGetDevice(¤t_device); + + if (current_device != device_) { + cudaSetDevice(device_); + } cudaFree(pointer_); + + if (current_device != device_) { + cudaSetDevice(current_device); + } } type_ = library::NumericTypeID::kInvalid; @@ -2438,25 +2457,11 @@ void DeviceAllocation::fill_host(double val = 0.0) { cudaError_t DeviceAllocation::malloc(void** ptr, size_t size) { cudaError_t result; - int set_device_back_to = -1; - - /// When needed this sets the device to the allocation's device remembering - /// the current device so that it can be set back after the cudaMalloc is - /// performed. - if (device_ >= 0) { - int current_device; - result = cudaGetDevice(¤t_device); - if (result != cudaSuccess) { - return result; - } + int current_device; + cudaGetDevice(¤t_device); - if (current_device != device_) { - set_device_back_to = current_device; - result = cudaSetDevice(device_); - if (result != cudaSuccess) { - return result; - } - } + if (current_device != device_) { + cudaSetDevice(device_); } // This performs the cudaMalloc @@ -2465,13 +2470,8 @@ cudaError_t DeviceAllocation::malloc(void** ptr, size_t size) { return result; } - /// When needed this sets the device back to what it was when the function was - /// called. - if (set_device_back_to != -1) { - result = cudaSetDevice(set_device_back_to); - if (result != cudaSuccess) { - return result; - } + if (current_device != device_) { + cudaSetDevice(current_device); } return cudaSuccess; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 0256f0d099..1bed599f13 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -40,6 +40,7 @@ #include "cutlass/core_io.h" #include +#include #include "cutlass/profiler/cublas_helpers.h" #include "cutlass/profiler/gemm_operation_profiler.h" @@ -195,9 +196,6 @@ Status GemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) { // default value this->swizzle_size = 1; - if (this->swizzle_size <= 0) { - return Status::kErrorInvalidProblem; - } } if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) { @@ -371,31 +369,49 @@ Status GemmOperationProfiler::initialize_configuration( return status; } - gemm_workspace_.configuration.mode = problem_.mode; - gemm_workspace_.configuration.problem_size.m() = int(problem_.m); - gemm_workspace_.configuration.problem_size.n() = int(problem_.n); - gemm_workspace_.configuration.problem_size.k() = int(problem_.k); - gemm_workspace_.configuration.lda = problem_.lda; - gemm_workspace_.configuration.ldb = problem_.ldb; - gemm_workspace_.configuration.ldc = problem_.ldc; - gemm_workspace_.configuration.ldd = problem_.ldc; - - if (problem_.mode == library::GemmUniversalMode::kBatched) { - gemm_workspace_.configuration.batch_count = problem_.batch_count; - } - else { - gemm_workspace_.configuration.batch_count = problem_.split_k_slices; + const auto device_count = options.device.devices.size(); + + gemm_workspace_.clear(); + + for (size_t i = 0; i < device_count; ++i) { + cudaSetDevice(options.device.device_id(i)); + gemm_workspace_.emplace_back(); + cudaStreamCreateWithFlags(&gemm_workspace_[i].stream, cudaStreamNonBlocking); + gemm_workspace_[i].configuration.mode = problem_.mode; + gemm_workspace_[i].configuration.problem_size.m() = int(problem_.m); + gemm_workspace_[i].configuration.problem_size.n() = int(problem_.n); + gemm_workspace_[i].configuration.problem_size.k() = int(problem_.k); + gemm_workspace_[i].configuration.lda = problem_.lda; + gemm_workspace_[i].configuration.ldb = problem_.ldb; + gemm_workspace_[i].configuration.ldc = problem_.ldc; + gemm_workspace_[i].configuration.ldd = problem_.ldc; + + gemm_workspace_[i].configuration.device_count = static_cast(device_count); + gemm_workspace_[i].arguments.device_index = static_cast(i); + + if (problem_.mode == library::GemmUniversalMode::kBatched) { + gemm_workspace_[i].configuration.batch_count = problem_.batch_count; + } + else { + gemm_workspace_[i].configuration.batch_count = problem_.split_k_slices; + } + + gemm_workspace_[i].arguments.A = nullptr; + gemm_workspace_[i].arguments.B = nullptr; + gemm_workspace_[i].arguments.C = nullptr; + gemm_workspace_[i].arguments.D = nullptr; + gemm_workspace_[i].arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].arguments.beta = problem_.beta.data(); + gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].arguments.swizzle_size = problem_.swizzle_size; + gemm_workspace_[i].arguments.raster_order = problem_.raster_order; + initialize_result_(this->model_result_, options, operation_desc, problem_space); + + if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) { + return can_implement; + } } - gemm_workspace_.arguments.A = nullptr; - gemm_workspace_.arguments.B = nullptr; - gemm_workspace_.arguments.C = nullptr; - gemm_workspace_.arguments.D = nullptr; - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size; - gemm_workspace_.arguments.raster_order = problem_.raster_order; // initialize reduction operation for parallel splitKMode if (problem_.split_k_mode == library::SplitKMode::kParallel) { if (!initialize_reduction_configuration_(operation, problem)) { @@ -403,9 +419,7 @@ Status GemmOperationProfiler::initialize_configuration( } } - initialize_result_(this->model_result_, options, operation_desc, problem_space); - - return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); + return status; } /// Initializes the performance result @@ -427,6 +441,7 @@ void GemmOperationProfiler::initialize_result_( result.bytes = problem_.bytes(operation_desc); result.flops = problem_.flops(operation_desc); result.runtime = 0; + result.runtime_vector.resize(options.device.devices.size(), 0); } @@ -447,12 +462,14 @@ bool GemmOperationProfiler::initialize_reduction_configuration_( } /// initialize library::ReductionConfiguration - gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); - gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); - gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); - gemm_workspace_.reduction_configuration.ldw = problem_.ldc; - gemm_workspace_.reduction_configuration.lds = problem_.ldc; - gemm_workspace_.reduction_configuration.ldd = problem_.ldc; + for (auto &gemm_workspace : gemm_workspace_) { + gemm_workspace.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); + gemm_workspace.reduction_configuration.partitions = int(problem_.split_k_slices); + gemm_workspace.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); + gemm_workspace.reduction_configuration.ldw = problem_.ldc; + gemm_workspace.reduction_configuration.lds = problem_.ldc; + gemm_workspace.reduction_configuration.ldd = problem_.ldc; + } // find reduction operation library::ReductionFunctionalKey reduction_key( @@ -485,11 +502,6 @@ Status GemmOperationProfiler::initialize_workspace( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - if (options.device.devices.size() != 1) { - throw std::runtime_error("This operation profiler only supports a single " - "device."); - } - cudaError_t result; result = cudaSetDevice(options.device.device_id(0)); if (result != cudaSuccess) { @@ -509,98 +521,103 @@ Status GemmOperationProfiler::initialize_workspace( bool is_sparse = operation_desc.tile_description.math_instruction.opcode_class == cutlass::library::OpcodeClassID::kSparseTensorOp; - // Compute the number of copies of the problem to avoid L2 camping. - if (!options.profiling.workspace_count) { - int64_t bytes = problem_.bytes(operation_desc); - if (bytes < 3 * int64_t(options.device.properties[0].l2CacheSize)) { - gemm_workspace_.problem_count = - 1 + int((3 * int64_t(options.device.properties[0].l2CacheSize)) / bytes); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + + // Compute the number of copies of the problem to avoid L2 camping. + if (!options.profiling.workspace_count) { + int64_t bytes = problem_.bytes(operation_desc); + if (bytes < 3 * int64_t(options.device.properties[0].l2CacheSize)) { + gemm_workspace_[i].problem_count = + 1 + int((3 * int64_t(options.device.properties[0].l2CacheSize)) / bytes); + } + else { + gemm_workspace_[i].problem_count = 1; + } } else { - gemm_workspace_.problem_count = 1; + gemm_workspace_[i].problem_count = options.profiling.workspace_count; } - } - else { - gemm_workspace_.problem_count = options.profiling.workspace_count; - } - bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; - if (allocate_device_tensors) { - int seed_shift = 0; - gemm_workspace_.A = device_context.allocate_and_initialize_tensor( - options, - "A", - operation_desc.A.element, - operation_desc.A.layout, - {int(problem_.m), int(problem_.k)}, - {int(problem_.lda)}, - problem_.batch_count * gemm_workspace_.problem_count, - seed_shift++, - 0 // device_index - ); - - gemm_workspace_.B = device_context.allocate_and_initialize_tensor( - options, - "B", - operation_desc.B.element, - operation_desc.B.layout, - {int(problem_.k), int(problem_.n)}, - {int(problem_.ldb)}, - problem_.batch_count * gemm_workspace_.problem_count, - seed_shift++, - 0 // device_index - ); - - gemm_workspace_.C = device_context.allocate_and_initialize_tensor( - options, - "C", - operation_desc.C.element, - operation_desc.C.layout, - {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count, - seed_shift++, - 0 // device_index - ); - - gemm_workspace_.Computed = device_context.allocate_tensor( - options, - "D", - operation_desc.D.element, - operation_desc.D.layout, - {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count, - 0 // device_index - ); - - gemm_workspace_.Reference = device_context.allocate_tensor( - options, - "Reference", - operation_desc.D.element, - operation_desc.D.layout, - {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count, - 0 // device_index - ); - } - - if (options.execution_mode != ExecutionMode::kDryRun) { - // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels - gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; - gemm_workspace_.arguments.batch_count = problem_.batch_count; - gemm_workspace_.arguments.lda = problem_.lda; - gemm_workspace_.arguments.ldb = problem_.ldb; - gemm_workspace_.arguments.ldc = problem_.ldc; - gemm_workspace_.arguments.ldd = problem_.ldc; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; + if (allocate_device_tensors) { + int seed_shift = 0; + gemm_workspace_[i].A = device_context.allocate_and_initialize_tensor( + options, + "A", + operation_desc.A.element, + operation_desc.A.layout, + {int(problem_.m), int(problem_.k)}, + {int(problem_.lda)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + seed_shift++, + i // device_index + ); + + gemm_workspace_[i].B = device_context.allocate_and_initialize_tensor( + options, + "B", + operation_desc.B.element, + operation_desc.B.layout, + {int(problem_.k), int(problem_.n)}, + {int(problem_.ldb)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + seed_shift++, + i // device_index + ); + + gemm_workspace_[i].C = device_context.allocate_and_initialize_tensor( + options, + "C", + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + seed_shift++, + i // device_index + ); + + gemm_workspace_[i].Computed = device_context.allocate_tensor( + options, + "D", + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + + gemm_workspace_[i].Reference = device_context.allocate_tensor( + options, + "Reference", + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + } - /* Query device SM count to pass onto the kernel as an argument, where needed */ - gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount; + if (options.execution_mode != ExecutionMode::kDryRun) { + // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels + gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; + gemm_workspace_[i].arguments.batch_count = problem_.batch_count; + gemm_workspace_[i].arguments.lda = problem_.lda; + gemm_workspace_[i].arguments.ldb = problem_.ldb; + gemm_workspace_[i].arguments.ldc = problem_.ldc; + gemm_workspace_[i].arguments.ldd = problem_.ldc; + gemm_workspace_[i].arguments.batch_stride_A = gemm_workspace_[i].A->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_B = gemm_workspace_[i].B->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + gemm_workspace_[i].arguments.sm_count = options.device.properties[0].multiProcessorCount; + gemm_workspace_[i].arguments.device_index = static_cast(i); + } } // @@ -611,58 +628,69 @@ Status GemmOperationProfiler::initialize_workspace( if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { if (options.execution_mode != ExecutionMode::kDryRun) { - uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); - gemm_workspace_.host_workspace.resize(workspace_size, 0); - - workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, - &gemm_workspace_.arguments); - if (is_sparse) { - // sparse gemm get_device_workspace_size() only return device workspace size per iteration - // Needs to multiply it w/ number of iteration - workspace_size *= gemm_workspace_.problem_count; - } - gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); - - // Convert to structure sparse contents here. - if (is_sparse) { - uint8_t* profiler_workspaces[1]; - profiler_workspaces[0] = reinterpret_cast(gemm_workspace_.A->data()); - // Sparse operations have a different initialize interface. - // initialize_with_profiler_workspace converts mxk tensorA to compressed mxk/sp tensorA and the tensorE - auto modifiable_underlying_op = const_cast(underlying_operation); - status = modifiable_underlying_op->initialize_with_profiler_workspace( - &gemm_workspace_.configuration, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data(), - profiler_workspaces, - gemm_workspace_.problem_count); - } - else { - status = underlying_operation->initialize( - &gemm_workspace_.configuration, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data()); - } + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_[i].configuration); + gemm_workspace_[i].host_workspace.resize(workspace_size, 0); + + workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_[i].configuration, + &gemm_workspace_[i].arguments); + if (is_sparse) { + // sparse gemm get_device_workspace_size() only return device workspace size per iteration + // Needs to multiply it w/ number of iteration + workspace_size *= gemm_workspace_[i].problem_count; + } + gemm_workspace_[i].device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + + // Convert to structure sparse contents here. + if (is_sparse) { + uint8_t* profiler_workspaces[1]; + profiler_workspaces[0] = reinterpret_cast(gemm_workspace_[i].A->data()); + // Sparse operations have a different initialize interface. + // initialize_with_profiler_workspace converts mxk tensorA to compressed mxk/sp tensorA and the tensorE + auto modifiable_underlying_op = const_cast(underlying_operation); + status = modifiable_underlying_op->initialize_with_profiler_workspace( + &gemm_workspace_[i].configuration, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + profiler_workspaces, + gemm_workspace_[i].problem_count, + gemm_workspace_[i].stream); + } + else { + status = underlying_operation->initialize( + &gemm_workspace_[i].configuration, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); + } - if (status != Status::kSuccess) { - return status; - } + if (status != Status::kSuccess) { + return status; + } - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration); - gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_[i].reduction_configuration); + gemm_workspace_[i].reduction_host_workspace.resize(workspace_size, 0); - status = reduction_op_->initialize( - &gemm_workspace_.reduction_configuration, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + status = reduction_op_->initialize( + &gemm_workspace_[i].reduction_configuration, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); - if (status != Status::kSuccess) { - return status; + if (status != Status::kSuccess) { + return status; + } } } } + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaDeviceSynchronize(); + } + // // If CUTLASS is enabled, generate a result for it // @@ -698,29 +726,31 @@ bool GemmOperationProfiler::verify_cutlass( } // Initialize structure containing GEMM arguments - gemm_workspace_.arguments.A = gemm_workspace_.A->data(); - gemm_workspace_.arguments.B = gemm_workspace_.B->data(); - gemm_workspace_.arguments.C = gemm_workspace_.C->data(); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->data(); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->data(); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->data(); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].arguments.beta = problem_.beta.data(); + gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].arguments.batch_stride_A = gemm_workspace_[i].A->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_B = gemm_workspace_[i].B->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); - gemm_workspace_.arguments.beta = problem_.beta_zero.data(); - - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); - gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); - gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); - gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_[i].arguments.beta = problem_.beta_zero.data(); + + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->data(); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_[i].reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } } // @@ -737,27 +767,33 @@ bool GemmOperationProfiler::verify_cutlass( } } - results_.back().status = underlying_operation->run( - &gemm_workspace_.arguments, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data()); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); - if (results_.back().status != Status::kSuccess) { - results_.back().disposition = Disposition::kFailed; - return false; - } - - // Run parallel reduction kernel for parallel split_k_mode - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - results_.back().status = reduction_op_->run( - &gemm_workspace_.reduction_arguments, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + results_.back().status = underlying_operation->run( + &gemm_workspace_[i].arguments, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); if (results_.back().status != Status::kSuccess) { results_.back().disposition = Disposition::kFailed; return false; } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + results_.back().status = reduction_op_->run( + &gemm_workspace_[i].reduction_arguments, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); + + if (results_.back().status != Status::kSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + } } cudaError_t result = cudaDeviceSynchronize(); @@ -784,13 +820,17 @@ bool GemmOperationProfiler::verify_cutlass( if (cublas_satisfies(gemm_desc) == Status::kSuccess) { // call cublas verification if supported - verify_with_cublas_( - options, - report, - device_context, - operation, - problem_space, - problem); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + verify_with_cublas_( + options, + report, + device_context, + operation, + problem_space, + problem, + gemm_workspace_[i]); + } } else { @@ -852,7 +892,8 @@ bool GemmOperationProfiler::verify_with_cublas_( DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) { + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace_) { #if CUTLASS_ENABLE_CUBLAS @@ -983,115 +1024,119 @@ bool GemmOperationProfiler::verify_with_reference_( continue; } - void *ptr_A = gemm_workspace_.A->data(); - void *ptr_B = gemm_workspace_.B->data(); - void *ptr_C = gemm_workspace_.C->data(); - void *ptr_D = gemm_workspace_.Reference->data(); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); - // To support the host-side reference, conditionally allocate and - // copy tensors to host memory. - std::vector host_data_A; - std::vector host_data_B; - std::vector host_data_C; - std::vector host_data_D; + void *ptr_A = gemm_workspace_[i].A->data(); + void *ptr_B = gemm_workspace_[i].B->data(); + void *ptr_C = gemm_workspace_[i].C->data(); + void *ptr_D = gemm_workspace_[i].Reference->data(); - if (provider == library::Provider::kReferenceHost) { + // To support the host-side reference, conditionally allocate and + // copy tensors to host memory. + std::vector host_data_A; + std::vector host_data_B; + std::vector host_data_C; + std::vector host_data_D; - host_data_A.resize(gemm_workspace_.A->bytes()); - ptr_A = host_data_A.data(); - gemm_workspace_.A->copy_to_host(ptr_A); + if (provider == library::Provider::kReferenceHost) { - host_data_B.resize(gemm_workspace_.B->bytes()); - ptr_B = host_data_B.data(); - gemm_workspace_.B->copy_to_host(ptr_B); + host_data_A.resize(gemm_workspace_[i].A->bytes()); + ptr_A = host_data_A.data(); + gemm_workspace_[i].A->copy_to_host(ptr_A); - host_data_C.resize(gemm_workspace_.C->bytes()); - ptr_C = host_data_C.data(); - gemm_workspace_.C->copy_to_host(ptr_C); + host_data_B.resize(gemm_workspace_[i].B->bytes()); + ptr_B = host_data_B.data(); + gemm_workspace_[i].B->copy_to_host(ptr_B); - host_data_D.resize(gemm_workspace_.Reference->bytes()); - ptr_D = host_data_D.data(); - } + host_data_C.resize(gemm_workspace_[i].C->bytes()); + ptr_C = host_data_C.data(); + gemm_workspace_[i].C->copy_to_host(ptr_C); - // - // Launch - // + host_data_D.resize(gemm_workspace_[i].Reference->bytes()); + ptr_D = host_data_D.data(); + } - library::Handle handle; + // + // Launch + // - handle.set_provider(provider); + library::Handle handle; - Status status = handle.gemm_universal( - problem_.mode, - gemm_workspace_.configuration.problem_size.m(), - gemm_workspace_.configuration.problem_size.n(), - gemm_workspace_.configuration.problem_size.k(), - gemm_desc.tile_description.math_instruction.element_accumulator, - gemm_desc.element_epilogue, + handle.set_provider(provider); - problem_.alpha.data(), + Status status = handle.gemm_universal( + problem_.mode, + gemm_workspace_[i].configuration.problem_size.m(), + gemm_workspace_[i].configuration.problem_size.n(), + gemm_workspace_[i].configuration.problem_size.k(), + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, - element_A, - gemm_desc.A.layout, - gemm_desc.transform_A, - ptr_A, - int(gemm_workspace_.configuration.lda), + problem_.alpha.data(), - element_B, - gemm_desc.B.layout, - gemm_desc.transform_B, - ptr_B, - int(gemm_workspace_.configuration.ldb), + element_A, + gemm_desc.A.layout, + gemm_desc.transform_A, + ptr_A, + int(gemm_workspace_[i].configuration.lda), - problem_.beta.data(), + element_B, + gemm_desc.B.layout, + gemm_desc.transform_B, + ptr_B, + int(gemm_workspace_[i].configuration.ldb), - gemm_desc.C.element, - gemm_desc.C.layout, - ptr_C, - int(gemm_workspace_.configuration.ldc), + problem_.beta.data(), - gemm_desc.D.element, - gemm_desc.D.layout, - ptr_D, - int(gemm_workspace_.configuration.ldd), + gemm_desc.C.element, + gemm_desc.C.layout, + ptr_C, + int(gemm_workspace_[i].configuration.ldc), - gemm_workspace_.configuration.batch_count, - gemm_workspace_.A->batch_stride(), - gemm_workspace_.B->batch_stride(), - gemm_workspace_.C->batch_stride(), - gemm_workspace_.Reference->batch_stride()); + gemm_desc.D.element, + gemm_desc.D.layout, + ptr_D, + int(gemm_workspace_[i].configuration.ldd), - if (status != Status::kSuccess) { - results_.back().verification_map[provider] = Disposition::kNotRun; - continue; - } - results_.back().status = status; + gemm_workspace_[i].configuration.batch_count, + gemm_workspace_[i].A->batch_stride(), + gemm_workspace_[i].B->batch_stride(), + gemm_workspace_[i].C->batch_stride(), + gemm_workspace_[i].Reference->batch_stride()); - if (provider == library::Provider::kReferenceHost) { - gemm_workspace_.Reference->copy_from_host(ptr_D); - } - - // - // Verify results - // + if (status != Status::kSuccess) { + results_.back().verification_map[provider] = Disposition::kNotRun; + continue; + } + results_.back().status = status; - results_.back().verification_map[provider] = compare_tensors( - options, - *gemm_workspace_.Computed, - *gemm_workspace_.Reference, - gemm_workspace_.Computed->batch_stride() - ); + if (provider == library::Provider::kReferenceHost) { + gemm_workspace_[i].Reference->copy_from_host(ptr_D); + } - // Save workspace if incorrect - if (options.verification.save_workspace == SaveWorkspace::kIncorrect && - results_.back().verification_map[provider] == Disposition::kIncorrect) { + // + // Verify results + // - save_workspace( - device_context, + results_.back().verification_map[provider] = compare_tensors( options, - gemm_desc, - library::Provider::kCUTLASS, - provider); + *gemm_workspace_[i].Computed, + *gemm_workspace_[i].Reference, + gemm_workspace_[i].Computed->batch_stride() + ); + + // Save workspace if incorrect + if (options.verification.save_workspace == SaveWorkspace::kIncorrect && + results_.back().verification_map[provider] == Disposition::kIncorrect) { + + save_workspace( + device_context, + options, + gemm_desc, + library::Provider::kCUTLASS, + provider); + } } } @@ -1100,6 +1145,18 @@ bool GemmOperationProfiler::verify_with_reference_( ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace { +extern "C" { + __global__ void delay(cuda::atomic const* release) { + while (release->load(cuda::memory_order_acquire) != true) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + __nanosleep(100); +#endif + } + } +} +} + /// Measures performance results bool GemmOperationProfiler::profile( Options const &options, @@ -1111,39 +1168,41 @@ bool GemmOperationProfiler::profile( if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { - // Initialize structure containing GEMM arguments - gemm_workspace_.arguments.A = gemm_workspace_.A->data(); - gemm_workspace_.arguments.B = gemm_workspace_.B->data(); - gemm_workspace_.arguments.C = gemm_workspace_.C->data(); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + // Initialize structure containing GEMM arguments + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->data(); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->data(); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->data(); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].arguments.beta = problem_.beta.data(); + gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].arguments.batch_stride_A = gemm_workspace_[i].A->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_B = gemm_workspace_[i].B->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); - gemm_workspace_.arguments.beta = problem_.beta_zero.data(); - - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); - gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); - gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); - gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_[i].arguments.beta = problem_.beta_zero.data(); + + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->data(); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_[i].reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } } results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, - &gemm_workspace_.arguments, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data() + nullptr, + nullptr, + nullptr ); } return true; @@ -1153,14 +1212,22 @@ bool GemmOperationProfiler::profile( /// Method to profile a CUTLASS Operation Status GemmOperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace) { - - GpuTimer timer; + void *, + void *, + void *) { + + cuda::atomic *release; + cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable); + release->store(false, cuda::memory_order_release); + + std::vector timer; + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + timer.emplace_back(); + } // initialize gemm underlying operation to handle parallel reduction library::Operation const * underlying_operation = operation; @@ -1182,110 +1249,158 @@ Status GemmOperationProfiler::profile_cutlass_( Status status; - for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { - - int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; - - gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); - gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); - gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); - - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); - } + std::vector graphs; + graphs.resize(gemm_workspace_.size()); + std::vector graphExecs; + graphExecs.resize(gemm_workspace_.size()); + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaStreamBeginCapture(gemm_workspace_[i].stream, cudaStreamCaptureModeGlobal); + // Halt execution until all GPUs are ready to precede. + // It allows the CPU to trigger the GPUs all start at the same time. + delay<<<1, 1, 0, gemm_workspace_[i].stream>>>(release); + for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { + int problem_idx = (iteration % gemm_workspace_[i].problem_count) * problem_.batch_count; + + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->batch_data(problem_idx); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->batch_data(problem_idx); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->batch_data(problem_idx); - // Execute the CUTLASS operation - status = underlying_operation->run( - &gemm_workspace_.arguments, - host_workspace, - device_workspace); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); - if (status != Status::kSuccess) { - return status; - } + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->batch_data(problem_idx); + } - // Run parallel reduction kernel for parallel split_k_mode - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - status = reduction_op_->run( - &gemm_workspace_.reduction_arguments, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + // Execute the CUTLASS operation + status = underlying_operation->run( + &gemm_workspace_[i].arguments, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); if (status != Status::kSuccess) { return status; } - } - } - // - // Initialize GPU timer - // + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + status = reduction_op_->run( + &gemm_workspace_[i].reduction_arguments, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); - timer.start(); + if (status != Status::kSuccess) { + return status; + } + } + } - // - // Profiling loop - // + // + // Initialize GPU timer + // - int Iterations = options.profiling.iterations; + timer[i].start(gemm_workspace_[i].stream, cudaEventRecordExternal); - int iteration = 0; - for (; iteration < Iterations; ++iteration) { + // + // Profiling loop + // - // Iterate over copies of the problem in memory - int workspace_idx = options.profiling.warmup_iterations + iteration; - int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count; + int Iterations = options.profiling.iterations; - gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); - gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); - gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); + int iteration = 0; - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + for (; iteration < Iterations; ++iteration) { + // Iterate over copies of the problem in memory + int workspace_idx = options.profiling.warmup_iterations + iteration; + int problem_idx = (workspace_idx % gemm_workspace_[i].problem_count) * problem_.batch_count; - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); - } + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->batch_data(problem_idx); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->batch_data(problem_idx); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->batch_data(problem_idx); - status = underlying_operation->run( - arguments, - host_workspace, - device_workspace); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); - if (status != Status::kSuccess) { - return status; - } + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->batch_data(problem_idx); + } - // Run parallel reduction kernel for parallel split_k_mode - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - status = reduction_op_->run( - &gemm_workspace_.reduction_arguments, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + status = underlying_operation->run( + &gemm_workspace_[i].arguments, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); if (status != Status::kSuccess) { return status; } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + status = reduction_op_->run( + &gemm_workspace_[i].reduction_arguments, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); + + if (status != Status::kSuccess) { + return status; + } + } } + timer[i].stop(gemm_workspace_[i].stream, cudaEventRecordExternal); + cudaStreamEndCapture(gemm_workspace_[i].stream, &graphs[i]); + cudaGraphInstantiate(&graphExecs[i], graphs[i], nullptr, nullptr, 0); + } + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaGraphLaunch(graphExecs[i], gemm_workspace_[i].stream); } // // Wait for completion // - timer.stop_and_wait(); + release->store(true, cuda::memory_order_release); + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaStreamSynchronize(gemm_workspace_[i].stream); + } // // Update performance result // - runtime = timer.duration(iteration); + + result.runtime = 0; + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + result.runtime_vector[i] = timer[i].duration(options.profiling.iterations); + result.runtime += result.runtime_vector[i]; + } + result.runtime /= static_cast(gemm_workspace_.size()); + + cudaFreeHost(release); + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaGraphExecDestroy(graphExecs[i]); + cudaGraphDestroy(graphs[i]); + } + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(gemm_workspace_.size() - i - 1)); + timer.pop_back(); + } return status; } diff --git a/tools/profiler/src/gpu_timer.cpp b/tools/profiler/src/gpu_timer.cpp index cf03db1894..cd0e4df09d 100644 --- a/tools/profiler/src/gpu_timer.cpp +++ b/tools/profiler/src/gpu_timer.cpp @@ -33,9 +33,11 @@ */ #include +#include #include "cutlass/profiler/gpu_timer.h" + namespace cutlass { namespace profiler { @@ -52,32 +54,39 @@ GpuTimer::GpuTimer() { } } +GpuTimer::GpuTimer(GpuTimer&& gpu_timer) noexcept { + memcpy(events, gpu_timer.events, sizeof(events)); + memset(gpu_timer.events, 0, sizeof(gpu_timer.events)); +} + GpuTimer::~GpuTimer() { - for (auto & event : events) { - cudaEventDestroy(event); + for (const auto & event : events) { + if (event != nullptr) { + cudaEventDestroy(event); + } } } -/// Records a start event in the stream -void GpuTimer::start(cudaStream_t stream) { - cudaError_t result = cudaEventRecord(events[0], stream); +/// Records a start event in the stream, the flag is for cudaEventRecordWithFlags +void GpuTimer::start(cudaStream_t stream, const unsigned int flag) { + cudaError_t result = cudaEventRecordWithFlags(events[0], stream, flag); if (result != cudaSuccess) { throw std::runtime_error("Failed to record start event."); } } -/// Records a stop event in the stream -void GpuTimer::stop(cudaStream_t stream) { -cudaError_t result = cudaEventRecord(events[1], stream); +/// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags +void GpuTimer::stop(cudaStream_t stream, const unsigned int flag) { +cudaError_t result = cudaEventRecordWithFlags(events[1], stream, flag); if (result != cudaSuccess) { throw std::runtime_error("Failed to record stop event."); } } -/// Records a stop event in the stream and synchronizes on the stream -void GpuTimer::stop_and_wait(cudaStream_t stream) { +/// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags +void GpuTimer::stop_and_wait(cudaStream_t stream, const unsigned int flag) { - stop(stream); + stop(stream, flag); cudaError_t result; if (stream) { diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index ce1ebb21f4..4d5c9d0973 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -658,7 +658,7 @@ void OperationProfiler::save_workspace( /// Method to profile a CUTLASS Operation Status OperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, @@ -726,7 +726,7 @@ Status OperationProfiler::profile_cutlass_( // Update performance result // - runtime = timer.duration(iteration); + result.runtime = timer.duration(iteration); return status; } diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index f1c1d7a77a..59368e9bad 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -307,12 +307,6 @@ void Options::Initialization::get_distribution( {0, 0} }; - // Initalize pnz values to a default value of 100% - dist.gaussian.pnz = 1.0; - dist.gaussian.pnzA = 1.0; - dist.gaussian.pnzB = 1.0; - dist.gaussian.pnzC = 1.0; - using KeyValueVector = std::vector >; KeyValueVector values; @@ -330,6 +324,25 @@ void Options::Initialization::get_distribution( ++it; } + // Default initialization + switch (dist.kind) { + case cutlass::Distribution::Uniform: + dist.set_uniform(-4/*min*/, 4/*max*/); + break; + case cutlass::Distribution::Gaussian: + dist.set_gaussian(0/*mean*/, 4/*stddev*/); + break; + case cutlass::Distribution::Identity: + dist.set_identity(); + break; + case cutlass::Distribution::Sequential: + dist.set_sequential(0/*start*/, 4/*delta*/); + break; + default: + dist.set_uniform(-4/*min*/, 4/*max*/); + return; + } + // Subsequent key-value pairs update the named field of the distribution struct. for (; it != values.end(); ++it) { // Integer scaling factor - if < 0, no integer rounding is performed. diff --git a/tools/profiler/src/performance_report.cpp b/tools/profiler/src/performance_report.cpp index 12855a2fe8..1d04f48f0e 100644 --- a/tools/profiler/src/performance_report.cpp +++ b/tools/profiler/src/performance_report.cpp @@ -337,7 +337,15 @@ std::ostream & PerformanceReport::print_csv_header_( << ",Bytes" << ",Flops" << ",Flops/Byte" - << ",Runtime" + << ",Runtime"; + + if (options_.device.devices.size() > 1) { + for (size_t i = 0; i < options_.device.devices.size(); i++) { + out << ",Runtime_" << i; + } + } + + out << ",GB/s" << ",GFLOPs" ; @@ -376,6 +384,16 @@ std::ostream & PerformanceReport::print_result_csv_( << "," << result.flops / result.bytes << "," << result.runtime; + if (options_.device.devices.size() > 1) { + if (result.runtime_vector.size() != options_.device.devices.size()) { + throw std::runtime_error("Runtime vector size mismatch"); + } + + for (const auto runtime : result.runtime_vector) { + out << "," << runtime; + } + } + if (result.good()) { out diff --git a/tools/profiler/src/rank_2k_operation_profiler.cu b/tools/profiler/src/rank_2k_operation_profiler.cu index df8ad40f3f..4b547a3e1f 100644 --- a/tools/profiler/src/rank_2k_operation_profiler.cu +++ b/tools/profiler/src/rank_2k_operation_profiler.cu @@ -733,7 +733,7 @@ bool Rank2KOperationProfiler::profile( rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &rank_k_workspace_.arguments, diff --git a/tools/profiler/src/rank_k_operation_profiler.cu b/tools/profiler/src/rank_k_operation_profiler.cu index 49fe54c846..52613b8ebd 100644 --- a/tools/profiler/src/rank_k_operation_profiler.cu +++ b/tools/profiler/src/rank_k_operation_profiler.cu @@ -718,7 +718,7 @@ bool RankKOperationProfiler::profile( rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &rank_k_workspace_.arguments, diff --git a/tools/profiler/src/sparse_gemm_operation_profiler.cu b/tools/profiler/src/sparse_gemm_operation_profiler.cu index 939608f8bb..ec14a33236 100644 --- a/tools/profiler/src/sparse_gemm_operation_profiler.cu +++ b/tools/profiler/src/sparse_gemm_operation_profiler.cu @@ -578,7 +578,7 @@ bool SparseGemmOperationProfiler::profile( gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &gemm_workspace_.arguments, diff --git a/tools/profiler/src/symm_operation_profiler.cu b/tools/profiler/src/symm_operation_profiler.cu index 59fcf0f147..80f645e75b 100644 --- a/tools/profiler/src/symm_operation_profiler.cu +++ b/tools/profiler/src/symm_operation_profiler.cu @@ -771,7 +771,7 @@ bool SymmOperationProfiler::profile( symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &symm_workspace_.arguments, diff --git a/tools/profiler/src/trmm_operation_profiler.cu b/tools/profiler/src/trmm_operation_profiler.cu index 5983b01168..9d3b4db6fb 100644 --- a/tools/profiler/src/trmm_operation_profiler.cu +++ b/tools/profiler/src/trmm_operation_profiler.cu @@ -709,7 +709,7 @@ bool TrmmOperationProfiler::profile( trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &trmm_workspace_.arguments, diff --git a/tools/util/include/cutlass/util/device_dump.h b/tools/util/include/cutlass/util/device_dump.h index cd0161603c..bb20e9b714 100644 --- a/tools/util/include/cutlass/util/device_dump.h +++ b/tools/util/include/cutlass/util/device_dump.h @@ -31,7 +31,7 @@ #pragma once -#include +#include #include "cutlass/cutlass.h" /** diff --git a/tools/util/include/cutlass/util/device_groupnorm.h b/tools/util/include/cutlass/util/device_groupnorm.h index 07f56c7154..5fc93a11b4 100644 --- a/tools/util/include/cutlass/util/device_groupnorm.h +++ b/tools/util/include/cutlass/util/device_groupnorm.h @@ -42,7 +42,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "device_utils.h" -#include +#include namespace cutlass { diff --git a/tools/util/include/cutlass/util/device_layernorm.h b/tools/util/include/cutlass/util/device_layernorm.h index 0ee58a7f7f..7708c3eba4 100644 --- a/tools/util/include/cutlass/util/device_layernorm.h +++ b/tools/util/include/cutlass/util/device_layernorm.h @@ -42,7 +42,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "device_utils.h" -#include +#include namespace cutlass { diff --git a/tools/util/include/cutlass/util/device_nhwc_pooling.h b/tools/util/include/cutlass/util/device_nhwc_pooling.h index 05fe5584a1..cce452d9eb 100644 --- a/tools/util/include/cutlass/util/device_nhwc_pooling.h +++ b/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -42,7 +42,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "device_utils.h" -#include +#include namespace cutlass { diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h b/tools/util/include/cutlass/util/device_rmsnorm.h index c4542eff06..44a1c08487 100644 --- a/tools/util/include/cutlass/util/device_rmsnorm.h +++ b/tools/util/include/cutlass/util/device_rmsnorm.h @@ -37,7 +37,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "cutlass/util/device_utils.h" -#include +#include namespace cutlass { @@ -165,12 +165,12 @@ void rmsnorm(cutlass::MatrixCoord tensor_size, dim3 grid(m); if (n % 8 == 0 && std::is_same::value) { - dim3 block(min(1024, (n / 8 + 31) / 32 * 32)); + dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); rmsnorm_twoPassAlgo_e8<<>>( (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); } else { - dim3 block(min(1024, ((n + 31)/32 + 31)/32*32)); + dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); rmsnorm_twoPassAlgo_e1<<>>( output, input, weight, m, n, epsilon); diff --git a/tools/util/include/cutlass/util/device_utils.h b/tools/util/include/cutlass/util/device_utils.h index 3ec078c886..7a8378fc2d 100644 --- a/tools/util/include/cutlass/util/device_utils.h +++ b/tools/util/include/cutlass/util/device_utils.h @@ -36,7 +36,7 @@ #pragma once #include -#include +#include #define FINAL_MASK 0xffffffff struct half4 { diff --git a/tools/util/include/cutlass/util/distribution.h b/tools/util/include/cutlass/util/distribution.h index 649a573603..086e033a90 100644 --- a/tools/util/include/cutlass/util/distribution.h +++ b/tools/util/include/cutlass/util/distribution.h @@ -100,6 +100,9 @@ struct Distribution { gaussian.mean = _mean; gaussian.stddev = _stddev; gaussian.pnz = _pnz; + gaussian.pnzA = _pnz; + gaussian.pnzB = _pnz; + gaussian.pnzC = _pnz; int_scale = _int_scale; return *this; } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 13aedf14d1..059076d957 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -138,8 +138,8 @@ struct RandomGaussianFunc { int_scale(int_scale_), exclude_zero(exclude_zero_) { - float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits - float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -175,8 +175,8 @@ struct RandomGaussianFunc { Element result; if (params.int_scale >= 0) { - rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); - result = Element(IntType(rnd * params.float_scale_down)); + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); } else { result = Element(rnd); @@ -237,7 +237,6 @@ struct RandomGaussianFunc> { exclude_zero(exclude_zero_) { float_scale_up = FloatType(IntType(1) << int_scale); - float_scale_up += FloatType(0.5) * float_scale_up; float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -276,8 +275,8 @@ struct RandomGaussianFunc> { Element result; if (params.int_scale >= 0) { - rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); - rnd_i = FloatType(IntType(rnd_i * params.float_scale_down)); + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); result = { Real(rnd_r * params.float_scale_down), @@ -482,8 +481,8 @@ struct RandomUniformFunc { pnan(pnan_), exclude_zero(exclude_zero_) { - float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits - float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); // Handle cases where min = 0 or max = 0 for excluding zeros if (exclude_zero >= 0) { @@ -535,8 +534,8 @@ struct RandomUniformFunc { Element result; if (params.int_scale >= 0) { - rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); - result = Element(IntType(rnd * params.float_scale_down)); + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); } else { result = Element(rnd); @@ -612,7 +611,6 @@ struct RandomUniformFunc> { exclude_zero(exclude_zero_) { float_scale_up = FloatType(IntType(1) << int_scale); - float_scale_up += FloatType(0.5) * float_scale_up; float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); // Handle cases where min = 0 or max = 0 for excluding zeros @@ -668,8 +666,8 @@ struct RandomUniformFunc> { Element result; if (params.int_scale >= 0) { - rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); - rnd_i = FloatType(IntType(rnd_i * params.float_scale_up)); + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); result = { Real(rnd_r * params.float_scale_down), diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index f6984fb2ba..184d773783 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -281,7 +281,6 @@ void gett_epilogue( cute::is_same_v; constexpr bool IsClamp = cute::is_same_v>; - constexpr bool IsBackpropFusion = cute::is_same_v> or cute::is_same_v>; diff --git a/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h index 090019c100..9e1ac76cda 100644 --- a/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h +++ b/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h @@ -41,7 +41,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" -#include +#include namespace cutlass { namespace reference { diff --git a/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/tools/util/include/cutlass/util/reference/host/rank_k_complex.h index ef44270a31..6f9d5dc40f 100644 --- a/tools/util/include/cutlass/util/reference/host/rank_k_complex.h +++ b/tools/util/include/cutlass/util/reference/host/rank_k_complex.h @@ -41,7 +41,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" -#include +#include namespace cutlass { namespace reference { diff --git a/tools/util/include/cutlass/util/reference/host/symm_complex.h b/tools/util/include/cutlass/util/reference/host/symm_complex.h index 2618feaa70..7a55bb39c6 100644 --- a/tools/util/include/cutlass/util/reference/host/symm_complex.h +++ b/tools/util/include/cutlass/util/reference/host/symm_complex.h @@ -41,7 +41,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" -#include +#include namespace cutlass { namespace reference { diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index b9f0c84d9a..85c70e41c3 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -269,8 +269,8 @@ struct RandomGaussianFunc > { // Sample from the Gaussian distribution for a nonzero element if (bernoulli_result) { if (int_scale >= 0) { - rnd[0] = double(int(rnd[0] * double(1 << int_scale))); - rnd[1] = double(int(rnd[1] * double(1 << int_scale))); + rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale))); + rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale))); reals[0] = from_real(rnd[0] / double(1 << int_scale)); reals[1] = from_real(rnd[1] / double(1 << int_scale)); } @@ -348,10 +348,10 @@ struct RandomGaussianFunc > { // Sample from the Gaussian distribution for a nonzero element if (bernoulli_result) { if (int_scale >= 0) { - rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); - rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); - rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); - rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); + rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale))); reals[0] = from_real(rnd1[0] / double(1 << int_scale)); reals[1] = from_real(rnd1[1] / double(1 << int_scale)); @@ -725,7 +725,7 @@ struct RandomUniformFunc > { // testing if (int_scale >= 0) { - rnd = double(int(rnd * double(1 << int_scale))); + rnd = double(std::llround(rnd * double(1 << int_scale))); reals[i] = from_real(Real(rnd / double(1 << int_scale))); } else { @@ -808,7 +808,7 @@ struct RandomUniformFunc > { // testing if (int_scale >= 0) { - rnd = double(int(rnd * double(1 << int_scale))); + rnd = double(std::llround(rnd * double(1 << int_scale))); reals[i] = from_real(Real(rnd / double(1 << int_scale))); } else { diff --git a/tools/util/include/cutlass/util/type_traits.h b/tools/util/include/cutlass/util/type_traits.h index 8379957aea..dec3168ea8 100644 --- a/tools/util/include/cutlass/util/type_traits.h +++ b/tools/util/include/cutlass/util/type_traits.h @@ -36,7 +36,7 @@ #include #include -#include +#include #include "cutlass/numeric_types.h" #include "cutlass/complex.h" From 87eaa6965a23b210faad62b386009546c311e37e Mon Sep 17 00:00:00 2001 From: Haicheng Wu Date: Tue, 24 Dec 2024 22:11:53 -0800 Subject: [PATCH 2/2] doc and swap stuff --- CHANGELOG.md | 7 ++++++- examples/55_hopper_mixed_dtype_gemm/README.md | 7 ++++++- .../mixed_dtype_utils.hpp | 20 +++++++++---------- .../packed_scale.hpp | 4 +++- .../detail/collective/mixed_input_utils.hpp | 11 +++++----- ...sm90_visitor_store_tma_warpspecialized.hpp | 2 +- include/cutlass/fast_math.h | 9 ++------- .../kernel/gemm_grouped_problem_visitor.h | 2 +- .../kernel/rank_2k_grouped_problem_visitor.h | 2 +- .../library/include/cutlass/library/library.h | 1 - 10 files changed, 36 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d0140a586..1ba870eba2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,12 @@ + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. -- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md). + + Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. + + Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading. + + Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`. + + Other general optimizations. +- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead. - [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). - [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). - [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 48eca35c2d..ecb4f41c97 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -9,12 +9,17 @@ This first version only supports mixed type GEMMs using TMA. ## Performance -While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type. +While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type. The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now. + +Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details. + + We are currently optimizing the following cases: 1. Memory bound cases for all types +2. `fp8 x {int2, uint2}` case ## Limitations diff --git a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp index fdb31316e7..55de3fabb3 100644 --- a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -151,16 +151,16 @@ void mixed_dtype_profiling( runtimes.reserve(options.iterations); for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { - cudaEventRecord(start); - CUTLASS_CHECK(gemm.run()); - cudaEventRecord(stop); - cudaEventSynchronize(stop); - - if (iter >= options.warmup) { - float milliseconds = 0; - cudaEventElapsedTime(&milliseconds, start, stop); - runtimes.push_back(milliseconds); - } + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } } cudaEventDestroy(start); diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp index 02e257c3fe..bd71e9cf28 100644 --- a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -33,6 +33,9 @@ #include + +#include "cutlass/util/device_memory.h" +#include "cutlass/integer_subbyte.h" #include "cutlass/float8.h" #include "cutlass/util/reference/device/tensor_fill.h" @@ -197,7 +200,6 @@ bool initialize_packed_scale( { cutlass::packed_scale_t tmp(data_in[i]); data_out[i] = reinterpret_cast const&>(tmp); - // std::cout << data_in[i] << ":" << std::hex << static_cast(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast((-data_in[i]).storage) << std::endl; } try { block_out.copy_from_host(data_out.data()); diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index c175538efa..c740eb98b2 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -237,7 +237,7 @@ struct LayoutAwareConvertImpl< } }; -// Specialization for UINT4 -> FPF16 with [02461357] value order +// Specialization for UINT4 -> FP16 with [02461357] value order template <> struct LayoutAwareConvertImpl< cutlass::uint4b_t, @@ -804,14 +804,15 @@ struct MixedInputUtils { { auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; asm volatile( "{\n" - " and .b32 %0, %2, %4 ;\n" \ - " and .b32 %1, %3, %5 ;\n" \ + " lop3 .b32 %0, %2, %4, %5, %6;\n" \ + " xor .b32 %1, %3, %5; \n" \ "}\n" : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) - : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0x7F7F7F00), "n"(0x7F7F7F7F) - ); + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) + ); } } CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 9c87be0809..83cfc030df 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -937,7 +937,7 @@ struct Sm90RowReduction { for (int v = 0; v < size(frg_A); ++v) { // Step1: swap if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second - swap(frg_A(v), frg_B(v)); + cutlass::swap(frg_A(v), frg_B(v)); } // Step2: shuffle diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index fa3873c5e7..4ca8e113db 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -38,7 +38,7 @@ #include #include #endif - +#include #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/uint128.h" @@ -54,12 +54,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTLASS_HOST_DEVICE void swap(T &lhs, T &rhs) { - T tmp = lhs; - lhs = rhs; - rhs = tmp; -} +using ::cuda::std::swap; /****************************************************************************** * Static math utilities diff --git a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h index 304f23e730..1c4411bd55 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h @@ -68,7 +68,7 @@ struct GemmGroupedProblemSizeHelper { CUTLASS_HOST_DEVICE static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { if (kTransposed) { - swap(problem.m(), problem.n()); + cutlass::swap(problem.m(), problem.n()); } } diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h index 2e31c7783a..054d2a73df 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h @@ -357,7 +357,7 @@ struct Rank2KGroupedProblemVisitor : public GroupedProblemVisitor< int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); if (kFillModeC == cutlass::FillMode::kUpper) { - swap(macro_row, macro_col); + cutlass::swap(macro_row, macro_col); } int32_t row = OffsetHelper::macro_row_to_row(macro_row, threadblock_id); diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index a3af54ba26..19812d4b94 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -103,7 +103,6 @@ class Operation { void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; - // Originally designed for metadata, but should be useful for FP8/6/4 too. virtual Status initialize_with_profiler_workspace( void const *configuration, void *host_workspace,