From d40c848bbb97a5cc89fecafaae222a8587e9f8e1 Mon Sep 17 00:00:00 2001 From: Mihir Awatramani Date: Fri, 24 Jan 2025 16:18:29 -0800 Subject: [PATCH] CUTLASS 3.8 Release --- ACTIVE_DEVELOPERS.md | 73 + CHANGELOG.md | 60 +- CMakeLists.txt | 20 + CONTRIBUTORS.md | 87 - README.md | 245 +- .../48_hopper_warp_specialized_gemm.cu | 8 + .../49_collective_builder.cu | 9 + .../50_hopper_gemm_with_epilogue_swizzle.cu | 9 + .../52_hopper_gather_scatter_fusion.cu | 7 + .../53_hopper_gemm_permute.cu | 7 + .../54_hopper_fp8_warp_specialized_gemm.cu | 7 + .../55_hopper_int4_bf16_gemm.cu | 7 + .../55_hopper_int4_fp8_gemm.cu | 7 + .../55_hopper_mixed_dtype_gemm.cu | 7 + examples/55_hopper_mixed_dtype_gemm/README.md | 6 +- .../mixed_dtype_utils.hpp | 20 +- .../packed_scale.hpp | 3 +- .../56_hopper_ptr_array_batched_gemm.cu | 7 + .../57_hopper_grouped_gemm.cu | 7 + .../61_hopper_gemm_with_topk_and_softmax.cu | 7 + .../62_hopper_sparse_gemm.cu | 8 + .../63_hopper_gemm_with_weight_prefetch.cu | 7 + .../65_distributed_gemm.cu | 6 +- .../70_blackwell_fp16_gemm.cu | 483 ++ .../70_blackwell_fp8_gemm.cu | 671 ++ examples/70_blackwell_gemm/CMakeLists.txt | 41 + ..._blackwell_gemm_with_collective_builder.cu | 570 ++ .../CMakeLists.txt | 35 + .../72a_blackwell_nvfp4_bf16_gemm.cu | 544 ++ .../72b_blackwell_nvfp4_nvfp4_gemm.cu | 594 ++ .../72c_blackwell_mixed_mxfp8_bf16_gemm.cu | 545 ++ .../CMakeLists.txt | 46 + .../CMakeLists.txt | 36 + .../blackwell_gemm_preferred_cluster.cu | 541 ++ .../74_blackwell_gemm_streamk/CMakeLists.txt | 37 + .../blackwell_gemm_streamk.cu | 592 ++ .../75_blackwell_grouped_gemm.cu | 813 ++ .../75_blackwell_grouped_gemm_block_scaled.cu | 953 +++ .../75_blackwell_grouped_gemm/CMakeLists.txt | 88 + .../76_blackwell_conv_dgrad.cu | 534 ++ .../76_blackwell_conv_fprop.cu | 534 ++ .../76_blackwell_conv_wgrad.cu | 530 ++ examples/76_blackwell_conv/CMakeLists.txt | 46 + .../77_blackwell_fmha/77_blackwell_fmha.cu | 990 +++ .../77_blackwell_fmha_gen.cu | 832 ++ examples/77_blackwell_fmha/CMakeLists.txt | 105 + examples/77_blackwell_fmha/README.md | 23 + .../collective/fmha_common.hpp | 127 + .../collective/fmha_fusion.hpp | 254 + ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 200 + ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 1102 +++ ...m100_fmha_gen_epilogue_warpspecialized.hpp | 94 + ...m100_fmha_gen_mainloop_warpspecialized.hpp | 1116 +++ ...m100_fmha_load_cpasync_warpspecialized.hpp | 395 + .../sm100_fmha_load_tma_warpspecialized.hpp | 316 + examples/77_blackwell_fmha/device/fmha.hpp | 276 + .../77_blackwell_fmha/kernel/fmha_options.hpp | 85 + .../kernel/fmha_tile_scheduler.hpp | 162 + ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 519 ++ .../sm100_fmha_gen_kernel_warpspecialized.hpp | 576 ++ .../reference/fmha_fwd_gen_reference.hpp | 187 + .../reference/fmha_fwd_reference.hpp | 163 + .../reference/reference_abs_error.hpp | 180 + examples/CMakeLists.txt | 8 + examples/README.md | 293 + include/cute/arch/cluster_sm100.hpp | 108 + include/cute/arch/config.hpp | 39 + include/cute/arch/copy_sm100.hpp | 7567 +++++++++++++++++ include/cute/arch/copy_sm100_tma.hpp | 664 ++ include/cute/arch/copy_sm90_desc.hpp | 38 + include/cute/arch/copy_sm90_tma.hpp | 27 + include/cute/arch/mma_sm100.hpp | 42 + include/cute/arch/mma_sm100_desc.hpp | 652 ++ include/cute/arch/mma_sm100_umma.hpp | 1074 +++ include/cute/arch/simd_sm100.hpp | 96 + include/cute/arch/tmem_allocator_sm100.hpp | 168 + include/cute/atom/copy_atom.hpp | 19 + include/cute/atom/copy_traits_sm100.hpp | 3797 +++++++++ 
.../cute/atom/copy_traits_sm100_im2col.hpp | 488 ++ include/cute/atom/copy_traits_sm100_tma.hpp | 487 ++ .../atom/copy_traits_sm90_tma_swizzle.hpp | 23 + include/cute/atom/mma_atom.hpp | 9 + include/cute/atom/mma_traits_sm100.hpp | 2425 ++++++ include/cute/atom/partitioner.hpp | 109 + include/cute/container/tuple.hpp | 6 +- include/cute/numeric/int.hpp | 2 + include/cute/numeric/numeric_types.hpp | 45 + include/cute/pointer.hpp | 105 + include/cute/tensor_zip.hpp | 3 + include/cute/util/print.hpp | 5 + include/cutlass/arch/arch.h | 7 + include/cutlass/arch/barrier.h | 165 + include/cutlass/arch/config.h | 43 +- include/cutlass/arch/mma.h | 5 + include/cutlass/array.h | 132 +- include/cutlass/cluster_launch.hpp | 91 +- .../conv/collective/builders/sm100_common.inl | 193 + .../builders/sm100_umma_builder.inl | 225 + .../conv/collective/collective_builder.hpp | 1 + .../conv/collective/collective_conv.hpp | 1 + include/cutlass/conv/collective/detail.hpp | 21 +- ...100_implicit_gemm_umma_warpspecialized.hpp | 899 ++ .../conv/device/conv_universal_adapter.hpp | 4 +- include/cutlass/conv/dispatch_policy.hpp | 31 + .../cutlass/conv/kernel/conv_universal.hpp | 1 + .../conv/kernel/conv_universal_dispatch.hpp | 182 + ...m100_implicit_gemm_tma_warpspecialized.hpp | 911 ++ include/cutlass/core_io.h | 42 + include/cutlass/cuda_host_adapter.hpp | 17 + include/cutlass/detail/cluster.hpp | 99 + include/cutlass/detail/collective.hpp | 109 + .../detail/collective/mixed_input_utils.hpp | 11 +- include/cutlass/detail/layout.hpp | 28 + include/cutlass/detail/mma.hpp | 16 + .../detail/sm100_blockscaled_layout.hpp | 236 + include/cutlass/detail/sm100_tmem_helper.hpp | 76 + .../collective/builders/sm100_builder.inl | 1052 +++ .../collective/builders/sm90_builder.inl | 8 +- .../collective/collective_builder.hpp | 5 +- .../collective/collective_epilogue.hpp | 4 + .../epilogue/collective/default_epilogue.hpp | 31 +- .../collective/default_epilogue_array.hpp | 31 +- .../cutlass/epilogue/collective/detail.hpp | 350 + .../sm100_epilogue_array_nosmem.hpp | 453 + ...100_epilogue_array_tma_warpspecialized.hpp | 1190 +++ .../collective/sm100_epilogue_nosmem.hpp | 819 ++ .../sm100_epilogue_tma_warpspecialized.hpp | 1289 +++ include/cutlass/epilogue/dispatch_policy.hpp | 47 +- include/cutlass/epilogue/fusion/callbacks.hpp | 4 +- .../cutlass/epilogue/fusion/operations.hpp | 140 + .../sm100_callbacks_tma_warpspecialized.hpp | 955 +++ ...00_visitor_compute_tma_warpspecialized.hpp | 500 ++ ...m100_visitor_store_tma_warpspecialized.hpp | 338 + include/cutlass/exmy_base.h | 1219 +++ .../distributed/device/full_barrier.hpp | 2 +- include/cutlass/float8.h | 394 + include/cutlass/float_subbyte.h | 788 ++ include/cutlass/functional.h | 179 + .../sm100_blockscaled_umma_builder.inl | 782 ++ .../gemm/collective/builders/sm100_common.inl | 572 ++ .../builders/sm100_pipeline_carveout.inl | 117 + .../builders/sm100_umma_builder.inl | 320 + .../gemm/collective/collective_builder.hpp | 5 + .../gemm/collective/collective_mma.hpp | 6 + ..._blockscaled_mma_array_warpspecialized.hpp | 1268 +++ .../sm100_blockscaled_mma_warpspecialized.hpp | 1092 +++ .../sm100_mma_array_warpspecialized.hpp | 864 ++ .../collective/sm100_mma_warpspecialized.hpp | 723 ++ .../gemm/device/gemm_universal_adapter.h | 51 +- include/cutlass/gemm/dispatch_policy.hpp | 170 + .../cutlass/gemm/kernel/gemm_universal.hpp | 2 + .../sm100_gemm_array_tma_warpspecialized.hpp | 1142 +++ .../kernel/sm100_gemm_tma_warpspecialized.hpp | 1001 +++ .../gemm/kernel/sm100_tile_scheduler.hpp | 
723 ++ .../kernel/sm100_tile_scheduler_group.hpp | 309 + .../kernel/sm100_tile_scheduler_stream_k.hpp | 979 +++ ..._array_tma_warpspecialized_cooperative.hpp | 1 + ...emm_array_tma_warpspecialized_pingpong.hpp | 1 + .../kernel/sm90_gemm_tma_warpspecialized.hpp | 4 +- .../kernel/sm90_tile_scheduler_stream_k.hpp | 15 + .../cutlass/gemm/kernel/tile_scheduler.hpp | 114 + .../gemm/kernel/tile_scheduler_params.h | 808 +- include/cutlass/integer_subbyte.h | 9 + include/cutlass/kernel_hardware_info.h | 3 + include/cutlass/numeric_conversion.h | 2653 +++++- include/cutlass/numeric_types.h | 2 + include/cutlass/pipeline/pipeline.hpp | 2 + include/cutlass/pipeline/sm100_pipeline.hpp | 918 ++ include/cutlass/pipeline/sm90_pipeline.hpp | 22 + include/cutlass/relatively_equal.h | 30 + include/cutlass/version.h | 2 +- media/docs/blackwell_functionality.md | 584 ++ media/docs/dependent_kernel_launch.md | 22 +- media/docs/efficient_gemm.md | 6 +- media/docs/fundamental_types.md | 14 + media/docs/profiler.md | 8 + media/docs/quickstart.md | 102 +- media/images/M128xK4_scalefactor_gmem.png | Bin 0 -> 224698 bytes ...ss-3.8-blackwell-gemm-peak-performance.svg | 1 + ...rrow_precison_multiple_block_sf_layout.png | Bin 0 -> 37653 bytes pyproject.toml | 2 +- python/cutlass/__init__.py | 2 +- python/cutlass_library/gemm_operation.py | 110 + python/cutlass_library/generator.py | 2309 ++++- python/cutlass_library/library.py | 105 + python/cutlass_library/manifest.py | 1 + python/setup_library.py | 2 +- python/setup_pycute.py | 2 +- test/self_contained_includes/CMakeLists.txt | 25 + test/unit/common/filter_architecture.cpp | 1 + test/unit/core/numeric_conversion.cu | 12 + test/unit/gemm/device/CMakeLists.txt | 152 +- test/unit/gemm/device/gemm_testbed_3x.hpp | 994 ++- .../gemm/device/gemm_testbed_3x_ptr_array.hpp | 406 + .../gemm_testbed_3x_tensor_broadcast.hpp | 8 + .../CMakeLists.txt | 150 + .../mxf4_mxf4_void_f16_nt_layout.cu | 303 + .../mxf4_mxf4_void_f16_tn_layout.cu | 523 ++ .../mxf4_mxf6_f32_f16_nt_layout.cu | 304 + .../mxf4_mxf6_f32_f16_tn_layout.cu | 524 ++ .../mxf4_mxf8_bf16_bf16_nt_layout.cu | 524 ++ .../mxf4_mxf8_bf16_bf16_tn_layout.cu | 524 ++ .../mxf6_mxf4_f16_f16_nt_layout.cu | 304 + .../mxf6_mxf4_f16_f16_tn_layout.cu | 524 ++ .../mxf6_mxf6_void_bf16_nt_layout.cu | 304 + .../mxf6_mxf6_void_bf16_tn_layout.cu | 524 ++ .../mxf6_mxf8_void_f32_nt_layout.cu | 523 ++ .../mxf6_mxf8_void_f32_tn_layout.cu | 524 ++ .../mxf8_mxf4_f16_bf16_nt_layout.cu | 304 + .../mxf8_mxf4_f16_bf16_tn_layout.cu | 523 ++ .../mxf8_mxf6_f16_f8_nt_layout.cu | 304 + .../mxf8_mxf6_f16_f8_tn_layout.cu | 524 ++ .../mxf8_mxf8_void_f8_nt_layout.cu | 523 ++ .../mxf8_mxf8_void_f8_tn_layout.cu | 524 ++ .../nvf4_nvf4_bf16_bf16.cu | 683 ++ .../nvf4_nvf4_bf16_bf16_features.cu | 374 + .../nvf4_nvf4_f16_nvfp4_epilogue.cu | 436 + ..._bf16_bf16_bf16_tensor_op_f32_ptr_array.cu | 364 + .../sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu | 323 + ...emm_f16_f16_f16_tensor_op_f16_ptr_array.cu | 364 + ...mm_f16_f16_f16_tensor_op_f32_group_gemm.cu | 606 ++ ...emm_f16_f16_f16_tensor_op_f32_ptr_array.cu | 665 ++ ...gemm_f16_f16_f16_tensor_op_f32_stream_k.cu | 250 + .../sm100_gemm_f16_f16_f32_tensor_op_f32.cu | 104 + ...emm_f16_f16_f32_tensor_op_f32_ptr_array.cu | 664 ++ ...mm_f32_f32_f32_tensor_op_f32_group_gemm.cu | 606 ++ ...emm_f32_f32_f32_tensor_op_f32_ptr_array.cu | 667 ++ ...gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu | 327 + ..._gemm_f4_f4_f32_tensor_op_f32_ptr_array.cu | 327 + ...4_f4_f32_tensor_op_f32_runtime_datatype.cu | 156 + 
..._gemm_f6_f6_f32_tensor_op_f32_ptr_array.cu | 486 ++ ...6_f6_f32_tensor_op_f32_runtime_datatype.cu | 156 + ...8_f4_f32_tensor_op_f32_runtime_datatype.cu | 109 + ..._gemm_f8_f8_f8_tensor_op_f32_group_gemm.cu | 504 ++ ...0_gemm_f8_f8_f8_tensor_op_f32_ptr_array.cu | 465 + ...f8_f8_f8_tensor_op_f32_runtime_datatype.cu | 297 + ...f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu | 230 + ...0_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu | 284 + ...mxf4_mxf8_mxf8_tensor_op_f32_group_gemm.cu | 293 + ..._gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu | 281 + ...mxf8_mxf8_mxf8_tensor_op_f32_group_gemm.cu | 293 + ...2t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu | 3 + test/unit/gemm/threadblock/mma_multistage.cu | 42 + test/unit/pipeline/CMakeLists.txt | 1 + ...ontrol_async_warp_specialized_blackwell.cu | 381 + .../pipeline/testbed_cluster_launch_control.h | 154 + tools/library/CMakeLists.txt | 13 + .../include/cutlass/library/arch_mappings.h | 12 + .../include/cutlass/library/descriptions.h | 95 + .../library/include/cutlass/library/handle.h | 9 + .../library/include/cutlass/library/library.h | 60 + .../include/cutlass/library/operation_table.h | 188 + tools/library/include/cutlass/library/types.h | 28 + tools/library/include/cutlass/library/util.h | 11 + .../src/block_scaled_gemm_operation_3x.hpp | 450 + tools/library/src/gemm_operation_3x.hpp | 62 + tools/library/src/handle.cu | 13 + tools/library/src/library_internal.h | 42 + tools/library/src/operation_table.cu | 39 + .../reference/block_scaled_gemm_fp4a_vs16.cu | 128 + .../reference/block_scaled_gemm_fp4a_vs32.cu | 130 + .../block_scaled_gemm_mixed8bitsa.cu | 354 + .../block_scaled_gemm_reference_operation.h | 459 + tools/library/src/reference/gemm_f4_f4_f32.cu | 109 + tools/library/src/reference/gemm_f4_f6_f32.cu | 110 + tools/library/src/reference/gemm_f4_f8_f32.cu | 110 + tools/library/src/reference/gemm_f6_f4_f32.cu | 110 + tools/library/src/reference/gemm_f6_f6_f32.cu | 109 + tools/library/src/reference/gemm_f6_f8_f32.cu | 110 + tools/library/src/reference/gemm_f8_f4_f32.cu | 110 + tools/library/src/reference/gemm_f8_f6_f32.cu | 110 + tools/library/src/reference/gemm_u8_u8_s32.cu | 11 + .../initialize_reference_operations.cu | 26 + .../library/src/sparse_gemm_operation_3x.hpp | 4 + tools/library/src/util.cu | 267 + tools/profiler/CMakeLists.txt | 2 + .../block_scaled_gemm_operation_profiler.h | 290 + .../profiler/gemm_operation_profiler.h | 14 + .../include/cutlass/profiler/problem_space.h | 12 + .../block_scaled_gemm_operation_profiler.cu | 1371 +++ tools/profiler/src/cutlass_profiler.cu | 3 + tools/profiler/src/device_allocation.cu | 325 + tools/profiler/src/device_context.cu | 19 + tools/profiler/src/gemm_operation_profiler.cu | 108 + tools/profiler/src/operation_profiler.cu | 17 + tools/profiler/src/options.cu | 25 +- tools/profiler/src/problem_space.cpp | 41 + .../util/reference/device/convolution.h | 2 +- .../cutlass/util/reference/host/gemm.h | 8 +- .../cutlass/util/reference/host/gett.hpp | 391 +- 290 files changed, 91885 insertions(+), 954 deletions(-) create mode 100644 ACTIVE_DEVELOPERS.md delete mode 100644 CONTRIBUTORS.md create mode 100644 examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu create mode 100644 examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu create mode 100644 examples/70_blackwell_gemm/CMakeLists.txt create mode 100644 examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu create mode 100644 examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt create mode 100644 
examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu create mode 100644 examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu create mode 100644 examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu create mode 100644 examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt create mode 100644 examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt create mode 100644 examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu create mode 100644 examples/74_blackwell_gemm_streamk/CMakeLists.txt create mode 100644 examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu create mode 100644 examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu create mode 100644 examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu create mode 100644 examples/75_blackwell_grouped_gemm/CMakeLists.txt create mode 100644 examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu create mode 100644 examples/76_blackwell_conv/76_blackwell_conv_fprop.cu create mode 100644 examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu create mode 100644 examples/76_blackwell_conv/CMakeLists.txt create mode 100644 examples/77_blackwell_fmha/77_blackwell_fmha.cu create mode 100644 examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu create mode 100644 examples/77_blackwell_fmha/CMakeLists.txt create mode 100644 examples/77_blackwell_fmha/README.md create mode 100644 examples/77_blackwell_fmha/collective/fmha_common.hpp create mode 100644 examples/77_blackwell_fmha/collective/fmha_fusion.hpp create mode 100644 examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/device/fmha.hpp create mode 100644 examples/77_blackwell_fmha/kernel/fmha_options.hpp create mode 100644 examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp create mode 100644 examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp create mode 100644 examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp create mode 100644 examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp create mode 100644 examples/77_blackwell_fmha/reference/reference_abs_error.hpp create mode 100644 examples/README.md create mode 100755 include/cute/arch/cluster_sm100.hpp create mode 100644 include/cute/arch/copy_sm100.hpp create mode 100644 include/cute/arch/copy_sm100_tma.hpp create mode 100644 include/cute/arch/mma_sm100.hpp create mode 100644 include/cute/arch/mma_sm100_desc.hpp create mode 100644 include/cute/arch/mma_sm100_umma.hpp create mode 100644 include/cute/arch/simd_sm100.hpp create mode 100644 include/cute/arch/tmem_allocator_sm100.hpp create mode 100644 include/cute/atom/copy_traits_sm100.hpp create mode 100644 include/cute/atom/copy_traits_sm100_im2col.hpp create mode 100644 include/cute/atom/copy_traits_sm100_tma.hpp create mode 100644 
include/cute/atom/mma_traits_sm100.hpp create mode 100644 include/cute/atom/partitioner.hpp create mode 100644 include/cutlass/conv/collective/builders/sm100_common.inl create mode 100644 include/cutlass/conv/collective/builders/sm100_umma_builder.inl create mode 100644 include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp create mode 100644 include/cutlass/conv/kernel/conv_universal_dispatch.hpp create mode 100644 include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp create mode 100644 include/cutlass/detail/cluster.hpp create mode 100644 include/cutlass/detail/sm100_blockscaled_layout.hpp create mode 100644 include/cutlass/detail/sm100_tmem_helper.hpp create mode 100644 include/cutlass/epilogue/collective/builders/sm100_builder.inl create mode 100644 include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp create mode 100644 include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp create mode 100644 include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp create mode 100644 include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp create mode 100644 include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp create mode 100644 include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp create mode 100644 include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp create mode 100644 include/cutlass/exmy_base.h create mode 100644 include/cutlass/float_subbyte.h create mode 100644 include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl create mode 100644 include/cutlass/gemm/collective/builders/sm100_common.inl create mode 100644 include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl create mode 100644 include/cutlass/gemm/collective/builders/sm100_umma_builder.inl create mode 100644 include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp create mode 100644 include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp create mode 100644 include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp create mode 100644 include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp create mode 100644 include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp create mode 100644 include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp create mode 100755 include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp create mode 100755 include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp create mode 100644 include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp create mode 100644 include/cutlass/pipeline/sm100_pipeline.hpp create mode 100644 media/docs/blackwell_functionality.md create mode 100644 media/images/M128xK4_scalefactor_gmem.png create mode 100644 media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg create mode 100644 media/images/narrow_precison_multiple_block_sf_layout.png create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu create mode 100644 
test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16_features.cu create mode 100644 test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_f16_nvfp4_epilogue.cu create mode 100644 test/unit/gemm/device/sm100_gemm_bf16_bf16_bf16_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f16_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_group_gemm.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_group_gemm.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu create mode 100644 
test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu create mode 100644 test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu create mode 100644 test/unit/gemm/device/sm100_gemm_mxf4_mxf8_mxf8_tensor_op_f32_group_gemm.cu create mode 100644 test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu create mode 100644 test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_group_gemm.cu create mode 100644 test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu create mode 100644 test/unit/pipeline/testbed_cluster_launch_control.h create mode 100644 tools/library/src/block_scaled_gemm_operation_3x.hpp create mode 100644 tools/library/src/reference/block_scaled_gemm_fp4a_vs16.cu create mode 100644 tools/library/src/reference/block_scaled_gemm_fp4a_vs32.cu create mode 100644 tools/library/src/reference/block_scaled_gemm_mixed8bitsa.cu create mode 100644 tools/library/src/reference/block_scaled_gemm_reference_operation.h create mode 100644 tools/library/src/reference/gemm_f4_f4_f32.cu create mode 100644 tools/library/src/reference/gemm_f4_f6_f32.cu create mode 100644 tools/library/src/reference/gemm_f4_f8_f32.cu create mode 100644 tools/library/src/reference/gemm_f6_f4_f32.cu create mode 100644 tools/library/src/reference/gemm_f6_f6_f32.cu create mode 100644 tools/library/src/reference/gemm_f6_f8_f32.cu create mode 100644 tools/library/src/reference/gemm_f8_f4_f32.cu create mode 100644 tools/library/src/reference/gemm_f8_f6_f32.cu create mode 100644 tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h create mode 100644 tools/profiler/src/block_scaled_gemm_operation_profiler.cu diff --git a/ACTIVE_DEVELOPERS.md b/ACTIVE_DEVELOPERS.md new file mode 100644 index 0000000000..6ae47b4373 --- /dev/null +++ b/ACTIVE_DEVELOPERS.md @@ -0,0 +1,73 @@ +![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS") + +[README](./README.md#documentation) > **Active Developers** + +# CUTLASS Developers ** + +Andrew Kerr (CUTLASS founding member)
+Dustyn Blasig
+Albert Xu
+Junkai Wu
+Xiuxia Zhang
+Haicheng Wu (CUTLASS founding member)
+Jack Yang
+Pradeep Ramani (CUTLASS 3.x founding member)
+Aditya Atluri
+Han Li
+Nick Zhao
+Ivan Yin
+Yu-Jung Chen
+Markus Hoehnerbach
+Honghao Lu
+Mihir Awatramani
+Hao Sheng
+Zekun Fan
+Aniket Shivam
+Siyu Liu
+Richard Cai
+Vikas Gupta
+Ethan Yan
+Vijay Thakkar (CUTLASS 3.x founding member)
+Cris Cecka (CuTe and CUTLASS 3.x founding member)
+Lawrence Ryan
+Qun Song
+Daniel Ricketts
+dePaul Miller
+Yuhan Li
+Saman Ashkiani
+Jack Chen
+Shang Zhang
+Petrick Liu
+Questa Wang
+Pramod Shenoy
+Jack Kosaian
+Yujia Zhai
+Zhaodong Chen
+Manas Sahni
+Shunfan Shao
+Fengqi Qiao
+Serif Yesil
+Aragorn Guan
+Heidi He
+Xiao Song
+Sergey Klevtsov
+Jiang Shao
+Ruqing Xu
+Mengyu Guo
+Tao Xie
+Linfeng Zheng
+Harrison Barclay
+Wenfei Tang
+Diksha Gohlyan
+Alexander Zhurkevich
+Siyuan Fu
+Hua Huang
+Xiufan Liang
+Ian Tramble
+Ali Hassani
+Shreya Gaur
+ +** _The list is sorted in order of the author's first contribution to the CUTLASS project._ + +# CUTLASS Product Manager +Matthew Nicely
diff --git a/CHANGELOG.md b/CHANGELOG.md index 76f6bf6a91..95419bcb38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,59 @@ # NVIDIA CUTLASS Changelog + +## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25) + +* Support for new CuTe building blocks specifically for Blackwell SM100 architecture: + - [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms. + - Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms. + - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale. + - Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe. + - [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms. + - Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms. +* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture: + - Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h) + - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). + - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). + - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./cutlass/media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. +* Full support for Blackwell SM100 kernels in CUTLASS 3.x API: + - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that + + Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture. + + Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators. + + Support stream-K load balancing for all kernel types everywhere via composable scheduler support. 
+ - Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for + * [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp) + * [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp) + * [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp) + * [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp) + - Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad. + - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions](). +* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification. + - Support for preferred and fallback cluster shapes via profiler command line arguments that set dynamic cluster shapes. + - Support for dynamic datatypes via profiler command line arguments that set the dynamic datatype in TCGen05 MMA instruction descriptors. +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture: + - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API. + - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell. + - Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores: + + [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu) + + [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu) + + [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu) + - GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy. + - [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu). + - Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu). 
+ - Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu). + - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128. +* Documentation updates: + - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). + - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/narrow_and_mixed_precision_gemms.md). + - A new [functionality documentation](./media/docs/functionality.md) specifically for the 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures. + - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). + ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). - [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b). -- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). +- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). - Enabled high precision accumulation for Hopper FP8 Sparse GEMM. - Potential API breaking changes: + Fix `cute::UniversalCopy` for type safety. @@ -22,12 +73,7 @@ + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. 
Now the 3.x convolution API is no longer considered as a beta API. -- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md). - + Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. - + Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading. - + Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`. - + Other general optimizations. -- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead. +- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. - [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). - [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). - [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. diff --git a/CMakeLists.txt b/CMakeLists.txt index e50fd76e11..9892f067d6 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -164,6 +164,11 @@ endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a) endif() + +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a) +endif() + set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") @@ -383,6 +388,21 @@ endif() + +################################################################################################### +# +# Blackwell features +# +################################################################################################### + +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUDA_BLACKWELL_TMA_SWIZZLE_ENABLED=1) + + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUDA_ENABLE_PREFERRED_CLUSTER=1) +endif() + + + # Warnings-as-error exceptions and warning suppressions for Clang builds if (CUTLASS_CLANG_HOST_COMPILE) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md deleted file mode 100644 index 538bb65843..0000000000 --- a/CONTRIBUTORS.md +++ /dev/null @@ -1,87 +0,0 @@ -![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS") - -[README](./README.md#documentation) > **Contributors** - -# CUTLASS Developers and Contributors - -This is the official list of CUTLASS developers and contributors. - -## DEVELOPERS -Vijay Thakkar
-Pradeep Ramani
-Cris Cecka
-Aniket Shivam
-Jack Kosaian
-Mark Hoemmen
-Richard Cai
-Honghao Lu
-Ethan Yan
-Haicheng Wu
-Andrew Kerr
-Dustyn Blasig
-Fengqi Qiao
-Duane Merrill
-Yujia Zhai
-Rawn Henry
-Sergey Klevtsov
-Shang Zhang
-Piotr Majcher
-Paul Springer
-Markus Hohnerbach
-Jin Wang
-Aditya Atluri
- -## CuTe -Cris Cecka
-Vijay Thakkar
- -## CUTLASS Product Manager -Matthew Nicely
- -## Former CUTLASS Developers -Manish Gupta
-Naila Farooqui
-David Tanner
-Manikandan Ananth
-Zhaodong Chen
-Chinmay Talegaonkar
- -## CONTRIBUTORS -Timothy Costa
-Julien Demouth
-Brian Fahs
-Michael Garland
-Michael Goldfarb
-Mostafa Hagog
-Fei Hu
-Alan Kaatz
-Tina Li
-Timmy Liu
-Wei Liu
-Tim Martin
-Duane Merrill
-Kevin Siu
-Markus Tavenrath
-John Tran
-Vicki Wang
-Junkai Wu
-Fung Xie
-Albert Xu
-Yang Xu
-Jack Yang
-Scott Yokim
-Xiuxia Zhang
-Nick Zhao
- -## ACKNOWLEDGEMENTS - -Girish Bharambe
-Luke Durant
-Carter Edwards
-Olivier Giroux
-Stephen Jones
-Rishkul Kulkarni
-Bryce Lelbach
-Joel McCormack
-Kyrylo Perelygin
-Sean Treichler
diff --git a/README.md b/README.md index 56637806c9..a74ac114b4 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.7.0 +# CUTLASS 3.8.0 -_CUTLASS 3.7.0 - January 2025_ +_CUTLASS 3.8.0 - January 2025_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -16,71 +16,96 @@ as building blocks within custom kernels and applications. To support a wide variety of applications, CUTLASS provides extensive support for mixed-precision computations, providing specialized data-movement and -multiply-accumulate abstractions for half-precision floating -point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32), -single-precision floating point (FP32), -[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm), -double-precision floating -point (FP64) types, integer data types (4b and 8b), and binary data types (1b). -CUTLASS demonstrates warp-synchronous matrix multiply operations +multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16, +[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm), + 8b floating point types (e5m2 and e4m3), + block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8), + narrow integer types (4 and 8b signed and unsigned integers), + and binary 1b data types (where architectures allow for the +native support of such data types). +CUTLASS demonstrates optimal matrix multiply operations targeting the programmable, high-throughput _Tensor Cores_ implemented by -NVIDIA's Volta, Turing, Ampere, and Hopper architectures. +NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures. -See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly. - -See the [functionality listing](./media/docs/functionality.md) for the list of operations -supported at each level of the execution model hierarchy. - -CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data. -CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. - -The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. - -CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design -and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cute/00_quickstart.md). 
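To make the `Layout`/`Tensor` description above concrete, here is a minimal host-side CuTe sketch (illustrative only, not part of this patch). It assumes only the public CuTe API (`make_layout`, `make_tensor`, `print`), a C++17 compiler, and the CUTLASS include path; the shapes, strides, and values are arbitrary choices for the example.

```cpp
#include <cassert>
#include <vector>

#include <cute/tensor.hpp>  // CuTe Layout, Tensor, and layout algebra

int main() {
  using namespace cute;

  // A 4x8 row-major layout: shape (4,8) with strides (8,1).
  auto layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
                            make_stride(Int<8>{}, Int<1>{}));

  // A Layout is a function from logical coordinates to linear offsets.
  int offset = layout(2, 3);            // 2*8 + 3 == 19
  assert(offset == 19);

  // Compose the layout with storage to get a Tensor; all indexing goes
  // through the layout, so callers never write manual stride arithmetic.
  std::vector<float> buf(size(layout));
  auto tensor = make_tensor(buf.data(), layout);
  tensor(2, 3) = 1.0f;                  // writes buf[19]

  print(layout);                        // e.g. (_4,_8):(_8,_1)
  return 0;
}
```

The same layout algebra (composition, tiling, partitioning) underlies the SM100 MMA and copy atoms added in this release.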
- -In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. - -# What's New in CUTLASS 3.7 - -CUTLASS 3.7.0 is an update to CUTLASS adding: - -- A new [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) where the operands and block scaling tensor are staged via shared memory. -- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is an experimental pipelined Tensor Parallelism implementation utilizing existing CUTLASS kernels and CUDA runtime features, which can hide the most of communication behind computation. -- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). -- Enabled high precision accumulation for Hopper FP8 Sparse GEMM. -- Potential API breaking changes: - + Fix `cute::UniversalCopy` for type safety. - + No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom. - + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication. - + Remove `cute::copy_vec` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment,...)`. - + A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel. -- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler). -- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +In addition to GEMMs, CUTLASS implements high-performance convolution via +the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution +operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. +This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -Minimum requirements: - -- Architecture: Volta -- Compiler: Must support at least C++17 -- CUDA Toolkit version: 11.4 - -Starting from CUTLASS 3.0, CUTLASS removed support for the following: - -- Maxwell and Pascal GPU architectures -- Ubuntu 16.04 -- CUDA 10.2 -- C++ language versions less than 17. +See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly. -**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.** +See the [functionality docs](./media/docs/functionality.md) for a more comprehensive +list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU +architecture. + +# What's New in CUTLASS 3.8 + +CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture. 
+For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8. + +* Support for new CuTe building blocks specifically for Blackwell architecture: + - [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms. + - Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms. + - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp#L290) across CuTe as a first class data locale. + - Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe. + - [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms. + - Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms. +* Support for new CUTLASS building blocks specifically for Blackwell architecture: + - Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h) + - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). + - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). + - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./cutlass/media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. +* Full support for Blackwell kernels in CUTLASS 3.x API: + - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that + + Implement a new warp-specialization recipe tuned specifically for Blackwell. + + Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators. + + Support stream-K load balancing for all kernel types everywhere via composable scheduler support. + - Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for + * [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp) + * [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp) + * [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp) + * [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp) + - Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad. 
+ - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions](). +* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification. + - Support for preferred and fallback cluster shapes via profiler command line arguments that set dynamic cluster shapes. + - Support for dynamic datatypes via profiler command line arguments that set the dynamic datatype in TCGen05 MMA instruction descriptors. +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell + - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API (see the sketch after this list). + - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell. + - Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores: + + [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu) + + [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu) + + [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu) + - GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy. + - [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu). + - Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu). + - Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu). + - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128. +* Documentation updates: + - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). + - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/narrow_and_mixed_precision_gemms.md). + - A new [functionality documentation](./media/docs/functionality.md) specifically for the 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures. + - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). 
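As a companion to the example list above, the following is a minimal, hedged sketch of how a Blackwell GEMM is assembled through the 3.x collective builder API. It mirrors the Hopper builder pattern; the element types, 128x128x64 tile, 1x1x1 cluster, and alignment values are illustrative assumptions rather than the exact configuration used in example 70, and `KernelScheduleAuto`/`EpilogueScheduleAuto` simply ask the builder to pick a supported SM100 recipe. Build with `CUTLASS_NVCC_ARCHS="100a"`.

```cpp
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

using namespace cute;

// Illustrative choices; see example 70 for configurations validated on SM100.
using ElementA = cutlass::half_t;  using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;  using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = float;            using LayoutC = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;

using TileShape    = Shape<_128, _128, _64>;  // CTA tile (M, N, K) -- assumed
using ClusterShape = Shape<_1, _1, _1>;       // cluster shape      -- assumed

// Epilogue collective: reads accumulators (held in tmem on SM100) and writes C/D.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, 4,   // C: element, layout, alignment in elements
    ElementC, LayoutC, 4,   // D: element, layout, alignment in elements
    cutlass::epilogue::collective::EpilogueScheduleAuto
  >::CollectiveOp;

// Mainloop collective: TMA loads feeding the SM100 (TCGen05) tensor core MMA.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, 8,   // A: element, layout, alignment in elements
    ElementB, LayoutB, 8,   // B: element, layout, alignment in elements
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

// Kernel layer and device-facing adapter, exactly as in the Hopper examples.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,  // problem shape (M, N, K, batch)
    CollectiveMainloop,
    CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```

At run time the adapter is driven the same way as on Hopper: populate a `Gemm::Arguments`, then call `can_implement()`, `initialize()`, and `run()` on a `Gemm` instance.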
+ +Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits. +CUTLASS team is working on a fix. + +**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.** # Performance -

-

- CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels, -they exhibit peak performance comparable to cuBLAS for scalar GEMM -computations. The above figure shows the continual CUTLASS performance improvements +they exhibit nearly optimal utilization of peak theoretical throughput. The figure below +shows CUTLASS 3.8's performance as a % of theoretical peak utilization +on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU. + +

+ +The two figures below show the continual CUTLASS performance improvements on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since CUTLASS 3.1. CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads). @@ -88,20 +113,45 @@ Tensor Core operations are implemented using CUDA's [mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and [wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions. -

+[figures: CUTLASS GEMM performance improvements on an NVIDIA H100 across releases since CUTLASS 3.1]

-When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
-kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
-as shown in the above figure. Tensor Core operations are implemented using CUDA's
-[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
+# CuTe
+
+CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
+CuTe is a collection of C++ CUDA template abstractions for
+defining and operating on hierarchically multidimensional layouts of threads and data.
+CuTe provides `Layout` and `Tensor` objects that compactly package the type,
+shape, memory space, and layout of data, while performing the complicated indexing for the user.
+This lets programmers focus on the logical descriptions of their algorithms while
+CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design,
+implement, and modify all dense linear algebra operations.
+
+The core abstractions of CuTe are hierarchically multidimensional layouts
+which can be composed with data arrays to represent tensors.
+The representation of layouts is powerful enough to represent nearly
+everything we need to implement efficient dense linear algebra.
+Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
+
+CUTLASS 3.0 and later versions adopt CuTe throughout the GEMM hierarchy in their templates.
+This greatly simplifies the design and improves code composability and readability.
+More documentation specific to CuTe can be found in its
+[dedicated documentation directory](./media/docs/cute/00_quickstart.md).

 # Compatibility

+Minimum requirements:
+
+- Architecture: Volta (compute capability 7.0)
+- Compiler: Must support at least C++17
+- CUDA Toolkit version: 11.4
+
 CUTLASS requires a C++17 host compiler and
-performs best when built with the [**CUDA 12.4 Toolkit**](https://developer.nvidia.com/cuda-downloads).
-It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
+performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
+It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.

 ## Operating Systems
+
 We have tested the following environments.

 |**Operating System** | **Compiler** |
@@ -109,47 +159,74 @@ We have tested the following environments.
 | Ubuntu 18.04 | GCC 7.5.0 |
 | Ubuntu 20.04 | GCC 10.3.0 |
 | Ubuntu 22.04 | GCC 11.2.0 |
-| Ubuntu 22.04 | Clang 10.0.0 |
-| Ubuntu 22.04 | Clang 14.0.6 |
-| Ubuntu 22.04 | Clang 17.0.6 |
-| Windows 10.0 | Visual Studio 2019 v16.11.27 |

 Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.

+Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits.
+The CUTLASS team is working on a fix.
+
 ## Hardware
+
 CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture-based NVIDIA GPUs.
 |**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
 |---|---|---|
 |NVIDIA V100 Tensor Core GPU |7.0|11.4|
 |NVIDIA TitanV |7.0|11.4|
-|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4|
+|NVIDIA GeForce RTX 20x0 series |7.5|11.4|
 |NVIDIA T4 |7.5|11.4|
 |NVIDIA A100 Tensor Core GPU |8.0|11.4|
 |NVIDIA A10 |8.6|11.4|
-|NVIDIA GeForce RTX 3090 |8.6|11.4|
-|NVIDIA GeForce RTX 4090 |8.9|11.8|
+|NVIDIA GeForce RTX 30x0 series |8.6|11.4|
+|NVIDIA GeForce RTX 40x0 series |8.9|11.8|
 |NVIDIA L40 |8.9|11.8|
 |NVIDIA H100 Tensor Core GPU |9.0|11.8|
+|NVIDIA H200 Tensor Core GPU |9.0|11.8|
+|NVIDIA B200 Tensor Core GPU |10.0|12.8|

 ## Target Architecture

-In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
+In general, PTX code generated for one target architecture can be run on future architectures
+(i.e., it is forward compatible).
+However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose
+PTX does not have forward compatibility guarantees.
+Several Hopper and Blackwell PTX instructions fall under this category of
+architecture-accelerated features, and thus require a `sm_90a` or `sm100a` target architecture
+(note the "a" appended). For more details on this and other architecture-accelerated instructions,
+please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
+
+The target architecture information is passed on to CUTLASS via the cmake flag
+`CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100,
+users are required to build CUTLASS with `90a` as the target architecture.
+If a user accidentally builds a kernel which uses SM90a features
+(e.g. Hopper Tensor Core Instructions), using the SM90 target
+(note the lack of "a"), with either CUDA Toolkit 12 or 11.8,
+the kernel is expected to fail with a runtime error.
-The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error.
+```
+cmake .. -DCUTLASS_NVCC_ARCHS="90a"
+```
+Or
 ```
-cmake .. -DCUTLASS_NVCC_ARCHS="90a"
+cmake .. -DCUTLASS_NVCC_ARCHS="100a"
 ```
-Please refer to the [functionality documentation](./media/docs/functionality.md) for details on which kernels require which target architectures.
+Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
+products has a different compute capability than the one underpinning
+NVIDIA Blackwell GeForce RTX 50 series GPUs.
As a result, kernels +compiled for Blackwell SM100 architecture with arch conditional features +(using `sm100a`) are not compatible with RTX 50 series GPUs. + +Please refer to the [functionality documentation](./media/docs/functionality.md) +for details on which kernels require which target architectures. # Documentation CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). -- [Quick Start Guide](./media/docs/quickstart.md) - build and run CUTLASS +- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS - [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS - [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA - [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components @@ -163,7 +240,7 @@ CUTLASS is described in the following documents and the accompanying - [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory - [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory - [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application -- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilate rapid development +- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development - [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent kernels in the same stream, and how it is used in CUTLASS. @@ -171,11 +248,11 @@ kernels in the same stream, and how it is used in CUTLASS. We have also described the structure of an efficient GEMM in our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). 
- - [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/) - - [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/) - - [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/) - - [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/) - - [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/) +- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/) +- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/) +- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/) +- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/) +- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/) # Building CUTLASS diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index 7ed4593cf0..97e9061ed9 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -489,6 +489,14 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + if (props.major != 9 || props.minor != 0) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index 5852cd8d02..ee8415263e 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -540,6 +540,15 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; return 0; } + + else if (__CUDACC_VER_MAJOR__ < 12 || props.major != 9 || props.minor != 0) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture " + << "(compute capability 90) and CUDA 12.0 or greater.\n"; + return 0; + } + + // // Parse options // diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu index 69a3c030b0..6e91a3ba57 100644 --- a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -356,6 +356,15 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; return 0; } + + else if (__CUDACC_VER_MAJOR__ < 12 || props.major 
!= 9 || props.minor != 0) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture " + << "(compute capability 90) and CUDA 12.0 or greater.\n"; + return 0; + } + + // // Parse options // diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index 8a19842067..49505284db 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -626,6 +626,13 @@ int main(int argc, const char ** argv) { std::cerr << "This example requires a device with compute capability 90 or higher.\n"; notSupported = true; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + notSupported = true; + } + + if (notSupported) { return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems } diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu index d8096d9eb9..66ab9b241f 100644 --- a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -750,6 +750,13 @@ int main(int argc, char const **argv) std::cerr << "This example requires a device with compute capability 90 or higher.\n"; notSupported = true; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + notSupported = true; + } + + if (notSupported) { return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems } diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu index a6f33d1cd0..f250e4b93d 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -572,6 +572,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu index 8bca0a35ce..bdb59bfdcb 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -619,6 +619,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index 7bc65f9ba5..581ccf88ae 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -524,6 +524,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or 
greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index cf64f37b96..2629ee3396 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -489,6 +489,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 48eca35c2d..ca64c90136 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -7,14 +7,18 @@ When relying on `KernelScheduleAuto`, the main loop supporting different A and B This first version only supports mixed type GEMMs using TMA. + ## Performance -While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type. +While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type. The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now. +Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details. + We are currently optimizing the following cases: 1. Memory bound cases for all types +2. 
`fp8 x {int2, uint2}` case ## Limitations diff --git a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp index 56382d8386..724adabca9 100644 --- a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -151,16 +151,16 @@ void mixed_dtype_profiling( runtimes.reserve(options.iterations); for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { - cudaEventRecord(start); - CUTLASS_CHECK(gemm.run()); - cudaEventRecord(stop); - cudaEventSynchronize(stop); - - if (iter >= options.warmup) { - float milliseconds = 0; - cudaEventElapsedTime(&milliseconds, start, stop); - runtimes.push_back(milliseconds); - } + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } } cudaEventDestroy(start); diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp index 98d6df550e..a595ca7245 100644 --- a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -33,6 +33,8 @@ #include +#include "cutlass/util/device_memory.h" +#include "cutlass/integer_subbyte.h" #include "cutlass/float8.h" #include "cutlass/util/reference/device/tensor_fill.h" @@ -197,7 +199,6 @@ bool initialize_packed_scale( { cutlass::packed_scale_t tmp(data_in[i]); data_out[i] = reinterpret_cast const&>(tmp); - // std::cout << data_in[i] << ":" << std::hex << static_cast(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast((-data_in[i]).storage) << std::endl; } try { block_out.copy_from_host(data_out.data()); diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index 886d39a267..ec29bc0544 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -519,6 +519,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index 3aeafb4df5..0f014cc760 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -737,6 +737,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu index a71a63ebbd..ac21697ee7 100644 --- a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu +++ 
b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu @@ -507,6 +507,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu index 01a10046db..708d1db66f 100644 --- a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu +++ b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu @@ -576,6 +576,14 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + if (props.major != 9 || props.minor != 0) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu index d1db304b1f..03b54f3e01 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu +++ b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu @@ -475,6 +475,13 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } + + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + // // Parse options // diff --git a/examples/65_distributed_gemm/65_distributed_gemm.cu b/examples/65_distributed_gemm/65_distributed_gemm.cu index f0b59ca3ce..2289d62a8a 100644 --- a/examples/65_distributed_gemm/65_distributed_gemm.cu +++ b/examples/65_distributed_gemm/65_distributed_gemm.cu @@ -133,7 +133,8 @@ using namespace cute; using TP = _8; static constexpr int TP_ = TP{}; -#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) // Distributed GEMM tiling/sharding schedule // Choices: @@ -344,7 +345,8 @@ struct Result { }; -#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation diff --git a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu new file mode 100644 index 0000000000..39123cacf6 --- /dev/null +++ b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu @@ -0,0 +1,483 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A FP16 dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + + This example demonstrates minimal set of changes needed to transition from a Hopper CUTLASS 3.x + GEMM kernel (see example 48_hopper_warp_specialized_gemm) to a Blackwell 3.x CUTLASS GEMM kernel. + + The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features: + + 1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a) + which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA). + + Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + 2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a). + Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the + Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). 
+ + Usage: + $ ./examples/70_blackwell_gemm/70_blackwell_fp16_gemm --m=8192 --n=8192 --k=8192 +*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_64>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_2,_1>; +// Shape of the threadblocks participating in a tcgen05 MMA. 
<1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 +using AtomThrShape_MNK = Shape<_2, _1, _1>; +// Shape of the tile computed by each SM +using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(8192), n(8192), k(8192), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "70_blackwell_fp16_gemm\n\n" + << " Blackwell FP16 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "70_blackwell_fp16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, 
options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 100a. + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
+  if (props.major != 10 || props.minor != 0) {
+    std::cerr << "This example requires a GPU with compute capability 100a." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run<Gemm>(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
new file mode 100644
index 0000000000..0b1758b90f
--- /dev/null
+++ b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
@@ -0,0 +1,671 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief A FP8 dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
+
+    This example demonstrates the minimal set of changes needed to transition from a Hopper CUTLASS 3.x
+    FP8 GEMM kernel (see example 54_hopper_fp8_warp_specialized_gemm) to a Blackwell SM100 FP8 GEMM kernel.
+
+    This example shows all important fusions used by FP8 gemm kernels,
+    i.e., scale factors for the A, B, C, D tensors, and the abs_max value of the D tensor.
+
+    The Blackwell SM100 CUTLASS kernel uses the following Blackwell SM100 features:
+
+    1.
New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a) + which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA). + + Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + 2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a). + Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the + Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Usage: + $ ./examples/70_blackwell_gemm/70_blackwell_fp8_gemm --m=8192 --n=8192 --k=8192 +*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type 
+using ElementAccumulator = float; + +// Epilogue types +using ElementBias = cutlass::half_t; +using ElementCompute = float; +using ElementAux = ElementC; +using LayoutAux = LayoutC; +using ElementAmax = float; + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_64>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_2,_1>; +// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 +using AtomThrShape_MNK = Shape<_2, _1, _1>; +// Shape of the tile computed by each SM +using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); + +using FusionOp = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutC, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + PerSmTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutC, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOp + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. 
+using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ElementAmax = typename EpilogueOutputOp::ElementAmax; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; +using StrideAux = StrideC; + +constexpr bool IsDFp8 = + cute::is_same_v or + cute::is_same_v; + +constexpr bool IsAuxFp8 = + cute::is_same_v or + cute::is_same_v; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +StrideAux stride_aux; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; +cutlass::HostTensor tensor_aux; +cutlass::HostTensor tensor_ref_aux; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; +cutlass::HostTensor scale_A; +cutlass::HostTensor scale_B; +cutlass::HostTensor scale_C; +cutlass::HostTensor scale_D; +cutlass::HostTensor scale_aux; +cutlass::HostTensor abs_max_D; +cutlass::HostTensor reference_abs_max_D; +cutlass::HostTensor abs_max_aux; +cutlass::HostTensor reference_abs_max_aux; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "70_blackwell_fp8_gemm\n\n" + << " Blackwell FP8 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "70_blackwell_fp8_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + stride_aux = stride_D; + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m 
* options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + initialize_tensor(tensor_C.host_view(), seed + 2024); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + if (options.save_aux) { + tensor_aux.resize(c_coord); + tensor_aux.sync_device(); + tensor_ref_aux.resize(c_coord); + } + + if (options.device_scale) { + scalar_alpha.resize(cutlass::make_Coord(1)); + scalar_beta.resize(cutlass::make_Coord(1)); + scale_A.resize(cutlass::make_Coord(1)); + scale_B.resize(cutlass::make_Coord(1)); + scale_C.resize(cutlass::make_Coord(1)); + scale_D.resize(cutlass::make_Coord(1)); + scale_aux.resize(cutlass::make_Coord(1)); + + cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha); + cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta); + cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a); + cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b); + cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c); + cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d); + cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux); + + scalar_alpha.sync_device(); + scalar_beta.sync_device(); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + scale_aux.sync_device(); + } + + if (IsDFp8 && options.save_amax) { + abs_max_D.resize(cutlass::make_Coord(1)); + abs_max_D.sync_device(); + reference_abs_max_D.resize(cutlass::make_Coord(1)); + } + + if (IsAuxFp8 && options.save_aux && options.save_amax) { + abs_max_aux.resize(cutlass::make_Coord(1)); + abs_max_aux.sync_device(); + reference_abs_max_aux.resize(cutlass::make_Coord(1)); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + fusion_args.scale_a = options.scale_a; + fusion_args.scale_b = options.scale_b; + fusion_args.scale_c = options.scale_c; + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + + // ignored if tensor types are not fp8 + fusion_args.scale_d = options.scale_d; + fusion_args.scale_aux = options.scale_aux; + fusion_args.scale_d_ptr = scale_D.device_data(); + fusion_args.scale_aux_ptr = scale_aux.device_data(); + + // leaving/setting these as nullptr disables the fusion at runtime + fusion_args.bias_ptr = nullptr; + + if (options.save_aux) { + fusion_args.aux_ptr = tensor_aux.device_data(); + fusion_args.dAux = stride_aux; + if (options.save_amax) { + 
fusion_args.amax_aux_ptr = abs_max_aux.device_data(); + } + } + + if (options.save_amax) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto Aux = cute::make_tensor(tensor_ref_aux.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_aux)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + decltype(Aux), + unused_t, // valpha + unused_t, // vbeta + ActivationFunctor + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.Aux = Aux; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + epilogue_params.scale_a = options.scale_a; + epilogue_params.scale_b = options.scale_b; + epilogue_params.scale_c = options.scale_c; + epilogue_params.scale_d = options.scale_d; + epilogue_params.scale_aux = options.scale_aux; + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data(); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + if (IsDFp8 && options.save_amax) { + abs_max_D.sync_host(); + passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0)); + } + + if (options.save_aux) { + tensor_aux.sync_host(); + passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view()); + if (IsAuxFp8 && options.save_amax) { + abs_max_aux.sync_host(); + passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0)); + } + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + 
CUTLASS_CHECK(gemm.run());
+
+  // Check if output from CUTLASS kernel and reference kernel are equal or not
+  Result result;
+  result.passed = verify(options);
+
+  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
+
+  if (!result.passed) {
+    exit(-1);
+  }
+
+  // Run profiling loop
+  if (options.iterations > 0)
+  {
+    GpuTimer timer;
+    timer.start();
+    for (int iter = 0; iter < options.iterations; ++iter) {
+      CUTLASS_CHECK(gemm.run());
+    }
+    timer.stop();
+
+    // Compute average runtime and GFLOPs.
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
+  // and must have compute capability at least 100a.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer.\n";
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+
+  if (props.major != 10 || props.minor != 0) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Run
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run<Gemm>(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/70_blackwell_gemm/CMakeLists.txt b/examples/70_blackwell_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..d88a8c56e6
--- /dev/null
+++ b/examples/70_blackwell_gemm/CMakeLists.txt
@@ -0,0 +1,41 @@
+
+# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 70_blackwell_fp16_gemm + 70_blackwell_fp16_gemm.cu +) + +cutlass_example_add_executable( + 70_blackwell_fp8_gemm + 70_blackwell_fp8_gemm.cu +) +endif() diff --git a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu new file mode 100644 index 0000000000..6712d7a9f3 --- /dev/null +++ b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules + and epilogue visitor tree (EVT) construction + + Example usage: + $ ./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder \ + --m=2048 --n=2048 --k=2048 --l=2 +*/ + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l; + float alpha, beta; + + Options(): + help(false), + error(false), + m(2048), n(2048), k(2048), l(1), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "71_blackwell_gemm_with_collective_builder\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " performant kernels targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective +// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the +// number of pipeline stages. +template < + // Type of kernel schedule to generate + class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto, + // Type of epilogue schedule to generate + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, + // Number of pipeline stages to use + class StageCountType = cutlass::gemm::collective::StageCountAuto, + // Do we use custom epilogue visitor tree (EVT) fusion + bool UseCustomEVT = false +> +struct ExampleRunner { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + using ClusterShapeMNK = Shape<_2,_2,_1>; + static constexpr bool Use2SmMma = + // Manually specified 2sm cluster MMA schedule, will error if cluster M is not a multiple of 2 + std::is_same_v || + // Auto schedule will try to select 2sm cluster MMA based on cluster M + std::is_same_v && size<0>(ClusterShapeMNK{}) % 2 == 0; + // The MNK layout of CTAs within a cluster MMA + using AtomThrMNK = std::conditional_t, Shape<_1,_1,_1>>; + // The MMA tile used by the mainloop collective. 
Blackwell 1sm MMA supports up to MMA tile M = 128, 2sm MMA supports up to MMA tile M = 256 + using MmaTileMNK = std::conditional_t, Shape<_128,_128,_64>>; + // The Output tile used by the epilogue collective + using OutputTileMNK = decltype(shape_div(MmaTileMNK{}, AtomThrMNK{})); + + // 16B alignment lets us use TMA + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + + // Blackwell fusions for the most part use the same EVT nodes used in Hopper. Most Blackwell EVTs will alias to their Hopper counterparts. + // EVT nodes new to Blackwell mainly relate to narrow precision scale factor generation and are contained in include/cutlass/epilogue/fusion/sm100_visitor_*.hpp + // See include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp for EVT construction using these new nodes + // Fusions relating to narrow-precision scale factor generation are demonstrated in example 72b and can only be used in blackwell kernels + using CustomEVT = // alpha * acc + beta * C + cutlass::epilogue::fusion::Sm90EVT, // beta * C + (alpha * acc) + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // beta + cutlass::epilogue::fusion::Sm90SrcFetch, // C + cutlass::epilogue::fusion::Sm90EVT, // alpha * acc + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha + cutlass::epilogue::fusion::Sm90AccFetch // acc + > + >; + + // As in Hopper, a predefined set of fusion operations are provided in include/cutlass/epilogue/fusion/operations.hpp and can be passed to the epilogue builder + // Fusions operations supported by the Hopper TMA epilogue will also be supported by the Blackwell TMA epilogue + // Fusions relating to narrow-precision scale factor generation are demonstrated in example 72b and can only be used in blackwell kernels + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputTileMNK, ClusterShapeMNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueScheduleType, + cute::conditional_t + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileMNK, ClusterShapeMNK, + cute::conditional_t, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + StageCountType>, + MainloopScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t; + using LayoutTagB = 
cutlass::gemm::detail::StrideToLayoutTagB_t; + using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, float alpha, float beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementScalar(alpha), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementScalar(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{}, // epilogue.thread + block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + // See example 48 for details on custom EVT construction + if constexpr (UseCustomEVT) { + arguments.epilogue.thread = + { // ternary op : beta * C + (alpha * acc) + {{options.beta}}, // leaf op+args : beta + {}, // leaf op+args : C + { // binary op : alpha * acc + {{options.alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + // Pre-defined fusions 
will have flat, named args for user-friendlyness + else { + arguments.epilogue.thread.alpha = options.alpha; + arguments.epilogue.thread.beta = options.beta; + } + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + } + + return passed; + } + +}; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, bool passed) { + std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + +if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
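+  // A minimal sketch of how a non-default device could be picked up instead of hard-coding 0:
+  // query the device count and honor any cudaSetDevice() made by the caller. The lambda name
+  // below is illustrative only and is not used by the rest of this example.
+  auto select_device_id = []() {
+    int device_count = 0;
+    int selected = 0;
+    if (cudaGetDeviceCount(&device_count) == cudaSuccess && device_count > 0) {
+      cudaGetDevice(&selected);  // returns the device the calling thread made current
+    }
+    return selected;
+  };
+  (void) select_device_id;  // this example simply runs on device 0 below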
+  hw_info.device_id = 0;
+  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+  bool passed;
+
+  // Auto mainloop and epilogue schedules must be used together to guarantee functionality
+  ExampleRunner<> runner_0;
+  passed = runner_0.run(options, hw_info);
+  print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule", passed);
+
+  // Mainloop stage counts can be specified manually
+  // It is the user's responsibility to ensure there is enough device smem to allocate manual stage counts
+  ExampleRunner<
+    cutlass::gemm::collective::KernelScheduleAuto,
+    cutlass::epilogue::collective::EpilogueScheduleAuto,
+    _3> runner_1;
+  passed = runner_1.run(options, hw_info);
+  print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and 3 mainloop stages", passed);
+
+  // 1SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
+  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::NoSmemWarpSpecialized> runner_2;
+  passed = runner_2.run(options, hw_info);
+  print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
+
+  // 1SM cluster MMA mainloop schedules can also be used with 1SM TMA epilogue schedules
+  // 1SM cluster MMA mainloop schedules will not work with 2SM TMA epilogue schedules
+  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::TmaWarpSpecialized1Sm> runner_3;
+  passed = runner_3.run(options, hw_info);
+  print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue schedule", passed);
+
+  // 2SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
+  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::NoSmemWarpSpecialized> runner_4;
+  passed = runner_4.run(options, hw_info);
+  print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
+
+  // 2SM cluster MMA mainloop schedules can also be used with 2SM TMA epilogue schedules
+  // 2SM cluster MMA mainloop schedules will not work with 1SM TMA epilogue schedules
+  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::TmaWarpSpecialized2Sm> runner_5;
+  passed = runner_5.run(options, hw_info);
+  print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with TmaWarpSpecialized2Sm epilogue schedule", passed);
+
+  // Blackwell Auto schedule supports custom EVT fusions
+  constexpr bool UseCustomEVT = true;
+  ExampleRunner<
+    cutlass::gemm::collective::KernelScheduleAuto,
+    cutlass::epilogue::collective::EpilogueScheduleAuto,
+    cutlass::gemm::collective::StageCountAuto,
+    UseCustomEVT> runner_6;
+  passed = runner_6.run(options, hw_info);
+  print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and custom EVT", passed);
+
+  // 1SM TMA epilogue schedules support custom EVT fusions
+  ExampleRunner<
+    cutlass::gemm::KernelTmaWarpSpecialized1SmSm100,
+    cutlass::epilogue::TmaWarpSpecialized1Sm,
+    cutlass::gemm::collective::StageCountAuto,
+    UseCustomEVT> runner_7;
+  passed = runner_7.run(options, hw_info);
+  print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue and custom EVT", passed);
+
+  // 2SM TMA epilogue schedules support custom EVT fusions
+  ExampleRunner<
+    cutlass::gemm::KernelTmaWarpSpecialized2SmSm100,
+    cutlass::epilogue::TmaWarpSpecialized2Sm,
+    cutlass::gemm::collective::StageCountAuto,
+    UseCustomEVT> runner_8;
+  passed = runner_8.run(options, hw_info);
+  print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with TmaWarpSpecialized2Sm epilogue and custom EVT", passed);
+
+
+  // Blackwell direct store epilogue schedule
supports custom EVTs and named fusion operations as well (not supported for pre-Blackwell kernels) + ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, + cutlass::epilogue::NoSmemWarpSpecialized, + cutlass::gemm::collective::StageCountAuto, + UseCustomEVT> runner_9; + passed = runner_9.run(options, hw_info); + print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue and custom EVT", passed); + +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt new file mode 100644 index 0000000000..5bac649457 --- /dev/null +++ b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt @@ -0,0 +1,35 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Both filenames are shorter to avoid MAX_PATH issues on Windows. +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 71_blackwell_gemm_with_collective_builder + 71_blackwell_gemm_with_collective_builder.cu + ) +endif() diff --git a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu new file mode 100644 index 0000000000..ec597966c4 --- /dev/null +++ b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu @@ -0,0 +1,544 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM100 architecture. + + The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced + on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma) + and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 70_blackwell_gemm, this kernel leverages: + 1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). 
+ + Usage: + + $ ./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size +using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster +using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape, + 
cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + 
cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "72a_blackwell_nvfp4_bf16_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = 
Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking 
an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." 
<< std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu new file mode 100644 index 0000000000..cefa3e920d --- /dev/null +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -0,0 +1,594 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrate a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM100 architecture + on NVIDIA Blackwell SM100 architecture. The kernel outputs quantized fp4 values with scale factors that be the input of another GEMM. + + Similar to 72a_blackwell_nvfp4_bf16_gemm, this kernel leverages: + 1. Blockscaled tcgen05.mma instructions. + + 2. Per-SM memory called Tensor Memory (TMEM) + + 3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). 
+ + Usage: + + $ ./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand +using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFB matrix operand +using ElementC = float; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand + +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using MmaTileShape = Shape<_128,_128,_256>; // MMA's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using 
PerSmTileShape_MNK = Shape<_128,_128,_256>; // Threadblock-level tile size + +constexpr int InputSFVectorSize = 16; +constexpr int OutputSFVectorSize = InputSFVectorSize; + +// D = alpha * acc + beta * C +// With BlockScaleFactor generation. +using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, LayoutSFDTag, + ElementC>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
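+
+// Note (added for clarity): the scale-factor operands carry a full CuTe layout rather than a packed
+// stride because their elements are stored in a tiled atom order instead of plain row/column major.
+// One way to inspect the layout chosen for a given problem size is a sketch like the following
+// (illustrative only; m, n, k stand for the runtime problem extents used elsewhere in this example):
+//
+//   auto sfa_layout = Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
+//   cute::print(sfa_layout);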
+using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +using FusionOp = typename Gemm::EpilogueOutputOp; +constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; +using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig; +using LayoutSFD = typename SfdOutputCfg::LayoutSF; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +LayoutSFD layout_SFD; + +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensors +cutlass::HostTensor block_D; +cutlass::HostTensor block_SFD; +// Reference Output Tensors +cutlass::HostTensor block_reference_D; +cutlass::HostTensor block_reference_SFD; +// Matrix-wide normalization constant +cutlass::HostTensor block_Normconst; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "72b_blackwell_nvfp4_nvfp4_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + // For SFD tensor layout + using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = 
Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_Normconst.reset(cutlass::make_Coord(1)); + + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_Normconst.at(cutlass::make_Coord(0)) = 2; + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_D.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + block_SFD.sync_device(); + block_Normconst.sync_device(); + +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + { options.alpha, options.beta }, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D} + }; + + if constexpr (IsBlockScaleSupported) { + arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data(); + arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); + } + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + // think about how to simplify the gemm3x interface. 
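+  // Added note: the host reference below mirrors the device kernel's block-scaled math. The mainloop
+  // params dequantize A/B with their per-block scale factors and accumulate in fp32; the epilogue
+  // params apply alpha/beta, re-quantize D to fp4, and generate SFD values using the matrix-wide
+  // norm constant (SfStrategy::SfDGen). This is a summary for readability, not a specification of
+  // the reference API.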
+ cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementCompute, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D), // TensorD + decltype(tensor_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params {options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu new file mode 100644 index 0000000000..b73f2c9428 --- /dev/null +++ b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM100 architecture. + This Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced + on the Blackwell architecture (sm100a), which have the same throughput as fp8 Tensor Core MMA instructions (tcgen05.mma) + and 2x the throughput of fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 72a_blackwell_nvfp4_bf16_gemm, this kernel leverages: + 1. Blockscaled tcgen05.mma instructions. + + 2. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
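+
+    As a rough sketch of the mixed-input math (added for clarity; decode_fp8 and decode_fp4 are
+    illustrative names, not CUTLASS APIs): A is MXFP8 and B is MXFP4, and each block of 32 elements
+    along K (the MX block size) shares one ue8m0 (power-of-two) scale factor, so the accumulation is
+    approximately
+
+      acc(m,n) += decode_fp8(A(m,k)) * SFA(m, k/32) * decode_fp4(B(n,k)) * SFB(n, k/32);
+      D(m,n)    = bf16(alpha * acc(m,n) + beta * C(m,n));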
+ + Usage: + + $ ./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::mx_float4_t; // Element type for A matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size +using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster +using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape, + 
cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + 
cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "72c_blackwell_mixed_mxfp8_bf16_gemm\n\n" + << " Blackwell Mxfp8 x Mxfp4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = 
Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking 
an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt new file mode 100644 index 0000000000..fa80c184d5 --- /dev/null +++ b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt @@ -0,0 +1,46 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 72a_blackwell_nvfp4_bf16_gemm + 72a_blackwell_nvfp4_bf16_gemm.cu + ) + +cutlass_example_add_executable( + 72b_blackwell_nvfp4_nvfp4_gemm + 72b_blackwell_nvfp4_nvfp4_gemm.cu + ) + +cutlass_example_add_executable( + 72c_blackwell_mixed_mxfp8_bf16_gemm + 72c_blackwell_mixed_mxfp8_bf16_gemm.cu + ) +endif() diff --git a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt new file mode 100644 index 0000000000..0d0f7757cb --- /dev/null +++ b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 73_blackwell_gemm_preferred_cluster + blackwell_gemm_preferred_cluster.cu + ) +endif() diff --git a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu new file mode 100644 index 0000000000..fb62e844ff --- /dev/null +++ b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture with preferred cluster. + + With the introduction of NVIDIA Compute Capability 9.0, the CUDA programming model introduced + an optional hierarchy level known as Thread Block Clusters, which consist of multiple Thread Blocks. + While the CUDA programming model has supported the specification of cluster shapes at runtime + (Dynamic Clusters) since the Hopper architecture, CUTLASS has only provided support for Static + Clusters, meaning that cluster shapes must be defined at compile time. 
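+
+    As a concrete illustration of runtime cluster shapes (a sketch of what this example does later in
+    args_from_options; the dim3 values shown are this example's command-line defaults):
+
+      arguments.hw_info.cluster_shape          = dim3(4, 4, 1);   // preferred cluster
+      arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);   // fallback cluster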
+ + Larger cluster shapes can achieve higher TMA multicast but may result in poor SM occupancy due + to quantization. For instance, a 2x2 cluster on an 18 SM GPU would only utilize 16 SMs, leaving + 2 SMs idle. + + Starting with Compute Capability 10.0, the CUDA programming model adds the ability to specify + two clusters: preferred cluster and fallback cluster. For brevity, we refer to this as + Preferred Clusters. In the previous example, users can now launch an additional 2x1 cluster to + utilize the 2 idle SMs. + + With CUTLASS 3.8, in addition to Dynamic Clusters, CUTLASS adds support for Preferred Dynamic Cluster, + the ability for users to specify two cluster shapes at runtime. + + Terminology + * Static cluster: cluster shape is specified at compile time. + * Dynamic cluster: cluster shape is specified at runtime and set by the host. + * Preferred cluster: Kernel can be launched with two cluster shapes (preferred and fallback). + + Preferred and fallback cluster shapes are subject to several constraints. + * Preferred cluster depth (Z dimension) must be the same as that of fallback cluster. + * Fallback cluster shape must evenly divide the preferred cluster shape. + * Preferred cluster shape must evenly divide the kernel launch grid shape. + + This example demonstrates how to use the Dynamic Clusters and Preferred Clusters features in + CUTLASS 3.x Blackwell SM100 kernels. Users can specify preferred and fallback cluster shapes via GEMM arguments. + + # Example: + ./73_blackwell_gemm_preferred_cluster --m=4096 --n=4096 --k=4096 --preferred_cluster_m=4 --preferred_cluster_n=4 --fallback_cluster_m=2 --fallback_cluster_n=1 +*/ + + + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D
matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_64>; +// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 +using AtomThrShape_MNK = Shape<_2, _1, _1>; +// Shape of the tile computed by each SM +using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); +// Shape of the cluster set to to indicate dynamic cluster shape +using ClusterShape_MNK = Shape; +// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that +// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2 +// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1> +using KernelSchedule = cute::conditional_t; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void // <--- Default to cluster launch control (CLC) scheduler +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n; + + Options(): + help(false), + m(4096), n(4096), k(4096), + alpha(1.f), beta(0.f), + iterations(10), + preferred_cluster_m(4), + preferred_cluster_n(4), + fallback_cluster_m(2), + fallback_cluster_n(1) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4); + cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4); + cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2); + cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1); + + if (!validate_cluster_shape()){ + std::cout << "--Invalid cluster shapes" << std::endl; + help = true; + return; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "73_blackwell_gemm_preferred_cluster\n\n" + << " Blackwell FP16 GEMM using preferred cluster.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --preferred_cluster_m= Sets the M extent of preferred cluster shape\n" + << " --preferred_cluster_n= Sets the N extent of preferred cluster shape\n" + << " --fallback_cluster_m= Sets the M extent of fallback cluster shape\n" + << " --fallback_cluster_n= Sets the N extent of fallback cluster shape\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "Preferred cluster shape cannot be smaller than fallback cluster shape.\n" + << "Preferred cluster shape must be a multiple of fallback cluster shape.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "73_blackwell_gemm_preferred_cluster" << " --m=4096 --n=4096 --k=4096 --preferred_cluster_m=4 --preferred_cluster_n=4 --fallback_cluster_m=2 --fallback_cluster_m=1\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + private: + /// Validate preferred and fallback cluster shapes + bool validate_cluster_shape() { + if (preferred_cluster_m < fallback_cluster_m || preferred_cluster_n < fallback_cluster_n) { + std::cout << "--Preferred cluster cannot be smaller than fallback cluster" << std::endl; + return false; + } + + if (preferred_cluster_m % fallback_cluster_m != 0 || preferred_cluster_n % fallback_cluster_n != 0) { + std::cout << "--Preferred cluster must be a multiple of fallback cluster" << std::endl; + return false; + } + return true; + } + +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + 
cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block(cutlass::DeviceAllocation& block, uint64_t seed=2023) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1); + arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1); + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = 
cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +int run(Options &options) { + + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << "GEMM with" + << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k + << " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)" + << " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)" + << std::endl; + + std::cout << "--------------------------------------------------------------------------------" << std::endl; + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." 
<< std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} diff --git a/examples/74_blackwell_gemm_streamk/CMakeLists.txt b/examples/74_blackwell_gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..618561f529 --- /dev/null +++ b/examples/74_blackwell_gemm_streamk/CMakeLists.txt @@ -0,0 +1,37 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 74_blackwell_gemm_streamk + blackwell_gemm_streamk.cu + ) +endif() diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu new file mode 100644 index 0000000000..bb99fa4aff --- /dev/null +++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture with the Stream-K scheduler.
+
+    Stream-K is a GEMM parallelization technique that attempts to reduce load imbalance across SMs
+    by parallelizing certain output tiles across the K mode of the GEMM, without using a static splitting factor.
+    For complete details on Stream-K, please see https://arxiv.org/abs/2301.03598.
+
+    CUTLASS's Stream-K scheduler using the CUTLASS 3.x API supports several modes of
+    decomposing a GEMM (referred to as "decomposition modes" in this example):
+    * DataParallel: basic GEMM parallelized spatially via tiling, but without splitting the K mode
+    * SplitK: `split_factor` CTAs compute portions of the K mode for a given output tile and reduce their results
+    * StreamK: parallelizes work according to the stream-K load balancing method described in https://arxiv.org/abs/2301.03598
+    * Heuristic: applies an internal heuristic in an attempt to choose the most performant among the three preceding decomposition modes
+
+    Additionally, the Stream-K scheduler supports two different means of performing reductions for
+    decomposition modes that require reduction (SplitK, StreamK, and Heuristic):
+    * Deterministic: Participating CTAs perform reduction in a turnstile fashion in order of the K mode
+                     covered by each CTA. This requires a lock to be held exclusively by the CTA that is
+                     currently accumulating.
+    * Nondeterministic: Participating CTAs perform reduction atomically into the same workspace (mostly) without locking.
+                        Locks are used only to wait for the first CTA to write its partial values (to initialize the
+                        workspace), and for all but the final CTA to have accumulated (so that the final CTA can load
+                        the accumulated value and accumulate it into registers on top of which the epilogue will
+                        be performed). Due to the nondeterministic ordering of accumulation, deterministic numeric
+                        behavior cannot be guaranteed with this mode (e.g., floating-point rounding error will depend
+                        on the order of accumulation).
+
+    This example allows one to try out different decomposition modes, reduction modes, and (when using Split-K) splitting factors.
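+
+    As a rough sketch (mirroring the scheduler arguments this example sets up later in
+    args_from_options()), the selected decomposition mode, reduction mode, and split count
+    are forwarded to the stream-K tile scheduler through the kernel arguments:
+
+      arguments.scheduler.splits             = options.splits;
+      arguments.scheduler.decomposition_mode = options.decomposition_mode;
+      arguments.scheduler.reduction_mode     = options.reduction_mode;
+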
+ Here are a few examples of usage: + # Heuristic mode with deterministic reduction + ./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic + + # Stream-K mode with determinsitic reduction + ./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic + + # Split-K mode with a splitting factor of 2 and deterministic reduction + ./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=SplitK --reduction=Deterministic --splits=2 + + # Stream-K mode with nondeterministic reduction + ./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Nondeterministic +*/ + + + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_64>; +// Shape of the threadblocks participating in a tcgen05 MMA. 
<1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 +using AtomThrShape_MNK = Shape<_2, _1, _1>; +// Shape of the tile computed by each SM +using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); +// Shape of the cluster set to to indicate dynamic cluster shape +using ClusterShape_MNK = Shape; +// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that +// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2 +// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1> +using KernelSchedule = cute::conditional_t; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler // <--- Change needed to enable the stream-K scheduler +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n; + using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode; + DecompositionMode decomposition_mode; + ReductionMode reduction_mode; + int splits; + + std::unordered_map> dec_mappings = { + {DecompositionMode::Heuristic, {"Heuristic", "heuristic", "h", "H", ""}}, + {DecompositionMode::SplitK, {"SplitK", "split-k", "split-K", "Split-K", "Split-k", "splitk", "Splitk", "splitK", "spk", 
"SpK", "spK"}}, + {DecompositionMode::StreamK, {"StreamK", "stream-k", "stream-K", "Stream-K", "Stream-k", "streamk", "Streamk", "streamK", "stk", "StK", "stK"}}, + {DecompositionMode::DataParallel, {"DataParallel", "data-parallel", "dataparallel", "dp", "DP"}} + }; + + std::unordered_map> red_mappings = { + {ReductionMode::Deterministic, {"Deterministic", "deterministic", "d", "D", ""}}, + {ReductionMode::Nondeterministic, {"Nondeterministic", "nondeterministic", "n", "N"}} + }; + + Options(): + help(false), + m(256), n(256), k(16384), + alpha(1.f), beta(0.f), + iterations(10), + preferred_cluster_m(4), + preferred_cluster_n(4), + fallback_cluster_m(2), + fallback_cluster_n(1), + decomposition_mode(DecompositionMode::Heuristic), + reduction_mode(ReductionMode::Deterministic), + splits(1) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("splits", splits, 1); + cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4); + cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4); + cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2); + cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1); + + // Parse decompsition mode + std::string decomp_mode; + cmd.get_cmd_line_argument("decomposition", decomp_mode); + bool found = parse_from_options_map(decomp_mode, dec_mappings, decomposition_mode); + if (!found) { + std::cout << "--decomposition must be one of Heuristic, SplitK, StreamK, or DataParallel" << std::endl; + help = true; + return; + } + + // Parse reduction mode + std::string red_mode; + cmd.get_cmd_line_argument("reduction", red_mode); + found = parse_from_options_map(red_mode, red_mappings, reduction_mode); + if (!found) { + std::cout << "--reduction must be one of Deterministic and Nondeterministic" << std::endl; + help = true; + return; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "74_blackwell_gemm_streamk\n\n" + << " Blackwell FP16 GEMM using a stream-K kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --preferred_cluster_m= Sets the M extent of preferred cluster shape\n" + << " --preferred_cluster_n= Sets the N extent of preferred cluster shape\n" + << " --fallback_cluster_m= Sets the M extent of fallback cluster shape\n" + << " --fallback_cluster_n= Sets the N extent of fallback cluster shape\n" + << " --decomposition= Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n" + << " --reduction= Mode in which the stream-K kernel's reduction should be performed. 
Options: Deterministic (default), Nondeterministic\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "74_blackwell_gemm_streamk" << " --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + std::string decomposition_mode_str() const { + return dec_mappings.at(decomposition_mode).at(0); + } + + std::string reduction_mode_str() const { + return red_mappings.at(reduction_mode).at(0); + } + + private: + template + bool parse_from_options_map(std::string val, std::unordered_map> options, T& result) const { + for (const auto & [key, values] : options) { + if (std::find(values.begin(), values.end(), val) != values.end()) { + result = key; + return true; + } + } + return false; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block(cutlass::DeviceAllocation& block, uint64_t seed=2023) { + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + 
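+  // The dynamic preferred/fallback cluster shapes are passed through hw_info below, while the
+  // stream-K specific knobs (decomposition mode, reduction mode, split count) are passed through
+  // the tile scheduler arguments.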
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1); + arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1); + + arguments.scheduler.splits = options.splits; + arguments.scheduler.decomposition_mode = options.decomposition_mode; + arguments.scheduler.reduction_mode = options.reduction_mode; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +int run(Options &options) { + + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << "Stream-K GEMM with" + << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k + << " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)" + << " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n" + << " Decomposition_mode=" << options.decomposition_mode_str() + << " Split_count=" << options.splits + << " Reduction_mode=" << options.reduction_mode_str() + << std::endl; + + std::cout << "--------------------------------------------------------------------------------" << std::endl; + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
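+    // Note that each profiling iteration above re-runs gemm.initialize() inside the timed region,
+    // so the reported average includes argument/workspace initialization as well as kernel execution.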
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.8 Toolkit to run this example,
+  // and the GPU must have compute capability of at least 100.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+
+  if (props.major != 10 || (props.minor != 0 && props.minor != 1)) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
new file mode 100644
index 0000000000..520d8ceef9
--- /dev/null
+++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
@@ -0,0 +1,813 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM100 TensorOp-based warp-specialized kernel. + For this example all scheduling work is performed on the device. + The new feature showcased in this example is device-side modification of TMA descriptors + to move between groups/problem_count (represented by groups). + https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device + + To run this example: + + $ ./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Skipping any of the problem dimensions randomizes it across the different groups. + Same applies for alpha and beta values that are randomized across the different groups. + + To run this example for a set of problems using the benchmark option: + + $ ./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ElementC = cutlass::half_t; // Element type for C and D matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix 
configuration +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Runtime Cluster Shape +using ClusterShape = Shape; +// For Static Cluster Shape: +// using ClusterShape = Shape<_2,_1,_1>; // for example +// using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config +// using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder +// using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder + +// Different configs for 1SM and 2SM MMA kernel +struct MMA1SMConfig { + using MmaTileShape = Shape<_128,_256,Int<128 / sizeof(ElementA)>>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{})); +}; + +struct MMA2SMConfig { + using MmaTileShape = Shape<_256,_256,Int<128 / sizeof(ElementA)>>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{})); +}; + +template +struct GivenGemmSchedule { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + typename ScheduleConfig::OutputTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + typename ScheduleConfig::EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + typename ScheduleConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel1SM = 
GivenGemmSchedule::GemmKernel; +using Gemm1SM = GivenGemmSchedule::Gemm; +using Gemm = Gemm1SM; + +using GemmKernel2SM = GivenGemmSchedule::GemmKernel; +using Gemm2SM = GivenGemmSchedule::Gemm; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = FLT_MAX; + float beta = FLT_MAX; + int iterations = 10; + int m = 1024, n = 2048, k = 512, groups = 10; + dim3 cluster_shape = dim3(4,2,1); + dim3 cluster_shape_fallback = dim3(2,1,1); + RasterOrderOptions raster_order = RasterOrderOptions::AlongM; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("cluster_m", cluster_shape.x); + cmd.get_cmd_line_argument("cluster_n", cluster_shape.y); + cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x); + cmd.get_cmd_line_argument("cluster_fallback_n", 
cluster_shape_fallback.y); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (n < 1) { + n = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "75_blackwell_grouped_gemm\n\n" + << " Blackwell FP8 Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --cluster_m= and --cluster_n= Sets the X,Y dims of the preferred cluster shape\n" + << " --cluster_fallback_m= and --cluster_fallback_n= Sets the X,Y dims of the fallback cluster shape\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "75_blackwell_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + 
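+    // Each operand (A, B, C, D) is stored in a single packed device allocation shared by all
+    // groups; offset_* records where group i's sub-matrix begins within that allocation, and the
+    // per-group pointer arrays built later in initialize() are derived from these offsets.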
total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. 
+ hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + if (!is_static_v) { + if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 && + (options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) { + std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl; + } + hw_info.cluster_shape = options.cluster_shape; + hw_info.cluster_shape_fallback = options.cluster_shape_fallback; + } + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, then same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, then alpha/beta can differ between batches/groups. + if (options.alpha != FLT_MAX){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + if (options.beta != FLT_MAX) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {M, N, K}, + ElementAccumulator(alpha_host.at(i)), + ref_A, + ref_B, + ElementAccumulator(beta_host.at(i)), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed 
&= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || + ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) + ) + ) { + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 10 && props.minor == 0)) { + std::cerr + << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "Running kernel with 1SM MMA config:" << std::endl; + run(options, false /*host_problem_shapes_available*/); + std::cout << "Running kernel with 2SM MMA config:" << std::endl; + run(options, false /*host_problem_shapes_available*/); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu new file mode 100644 index 0000000000..1486a3c65e --- /dev/null +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -0,0 +1,953 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture. 
+ + This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM100 TensorOp-based warp-specialized kernel + for narrow precisions (FP4) with Scale Factors (In and Out). + For this example all scheduling work is performed on the device. + The new feature showcased in this example is device-side modification of TMA descriptors + to move between groups/problem_count (represented by groups). + https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device + + To run this example: + + $ ./examples/75_blackwell_grouped_gemm_block_scaled/75_blackwell_grouped_gemm_block_scaled --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Skipping any of the problem dimensions randomizes it across the different groups. + Same applies for alpha and beta values that are randomized across the different groups. + + To run this example for a set of problems using the benchmark option: + + $ ./examples/75_blackwell_grouped_gemm_block_scaled/75_blackwell_grouped_gemm_block_scaled --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands +using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands +using ElementC = cutlass::half_t; // Element type for C matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int 
AlignmentB = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = ElementC; // Element type for D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) +using ElementAccumulator = float; // Element type for internal accumulation + +// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands +constexpr int OutputSFVectorSize = 16; +using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< + cutlass::epilogue::thread::SiLu, + OutputSFVectorSize, + ElementD, + ElementAccumulator, + ElementSF, + LayoutC, + ElementC>; + +// Core kernel configurations +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag +using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Runtime Cluster Shape +using ClusterShape = Shape; +/* // For Static Cluster Shape: +use ClusterShape = Shape<_2,_1,_1> for example +using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config +using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder +using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder +*/ + +// Different configs for 1SM and 2SM MMA kernel +struct MMA1SMConfig { + using MmaTileShape = Shape<_128,_256,_256>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{})); +}; + +struct MMA2SMConfig { + using MmaTileShape = Shape<_256,_256,_256>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{})); +}; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, EpilogueOperatorClass, + typename MMA1SMConfig::OutputTileShape, ClusterShape, + Shape<_128,_64>, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule + // , FusionOperation // Enable for SF Output +>::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, MainloopOperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule +>::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + 
ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; +using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; +using Gemm = Gemm1SM; + +using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, EpilogueOperatorClass, + typename MMA2SMConfig::OutputTileShape, ClusterShape, + Shape<_128,_64>, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + typename MMA2SMConfig::EpilogueSchedule + // , FusionOperation // Enable for SF Output +>::CollectiveOp; +using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, MainloopOperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + typename MMA2SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA2SMConfig::KernelSchedule +>::CollectiveOp; +using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop2SM, + CollectiveEpilogue2SM +>; +using Gemm2SM = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; +using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; +using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig< + OutputSFVectorSize, + cute::is_same_v ? 
cute::UMMA::Major::K : cute::UMMA::Major::MN + >; +using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom; +using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF; + +// Host-side allocations +std::vector stride_A_host; +std::vector stride_B_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +using HostTensorA = cutlass::HostTensor; +using HostTensorB = cutlass::HostTensor; +using HostTensorSF = cutlass::HostTensor; +using HostTensorC = cutlass::HostTensor; +using HostTensorD = cutlass::HostTensor; +std::vector block_A; +std::vector block_B; +std::vector block_SFA; +std::vector block_SFB; +std::vector block_C; +std::vector block_D; +std::vector block_SFD; +std::vector block_ref_D; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFD; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +// A matrix wide constant value to scale the output matrix +// Avoids generating small FP4 values. +// NormConst is a single device-side constant value, its not per-batch or per-group +cutlass::DeviceAllocation norm_constant_device; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool verification = true; + + float alpha = FLT_MAX; + float beta = FLT_MAX; + float norm_constant = 1.0; + int iterations = 10; + int m = 1024, n = 2048, k = 512, groups = 10; + dim3 cluster_shape = dim3(2,1,1); + dim3 cluster_shape_fallback = dim3(2,1,1); + RasterOrderOptions raster_order = RasterOrderOptions::AlongN; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + if (cmd.check_cmd_line_flag("no-verif")) { + verification = false; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 
FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); + cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0)); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("cluster_m", cluster_shape.x); + cmd.get_cmd_line_argument("cluster_n", cluster_shape.y); + cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x); + cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (n < 1) { + n = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "75_blackwell_grouped_gemm_block_scaled\n\n" + << " Blackwell Block Scaled Narrow Precision Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --norm_constant= Epilogue scalar normalization constant for the output matrix\n\n" + << " --cluster_m= and --cluster_n= Sets the X,Y dims of the preferred cluster shape\n" + << " --cluster_fallback_m= and --cluster_fallback_n= Sets the X,Y dims of the fallback cluster shape\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --no-verif Do not run (host-side) verification kernels\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "75_blackwell_grouped_gemm_block_scaled" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 
1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + stride_A_host.push_back(stride_A); + stride_B_host.push_back(stride_B); + layout_SFA_host.push_back(layout_SFA); + layout_SFB_host.push_back(layout_SFB); + stride_C_host.push_back(stride_C); + stride_D_host.push_back(stride_D); + + block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A)))); + block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B)))); + block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA))))); + block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB))))); + block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C)))); + block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + } + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + uint64_t seed = 2020; + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_SFA_host(options.groups); + std::vector ptr_SFB_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_SFD_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + + initialize_block(block_A.at(i).host_view(), seed + 2021); + initialize_block(block_B.at(i).host_view(), seed + 2022); + initialize_block(block_C.at(i).host_view(), seed + 2023); + initialize_block(block_SFA.at(i).host_view(), seed + 2024); + initialize_block(block_SFB.at(i).host_view(), seed + 2025); + + block_A.at(i).sync_device(); + block_B.at(i).sync_device(); + block_C.at(i).sync_device(); + block_SFA.at(i).sync_device(); + block_SFB.at(i).sync_device(); + + ptr_A_host.at(i) = block_A.at(i).device_data(); + ptr_B_host.at(i) = block_B.at(i).device_data(); + ptr_SFA_host.at(i) = block_SFA.at(i).device_data(); + ptr_SFB_host.at(i) = block_SFB.at(i).device_data(); + ptr_C_host.at(i) = block_C.at(i).device_data(); + ptr_D_host.at(i) = block_D.at(i).device_data(); + ptr_SFD_host.at(i) = block_SFD.at(i).device_data(); + + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? 
static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(ptr_SFB_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_SFD.reset(options.groups); + ptr_SFD.copy_from_host(ptr_SFD_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + norm_constant_device.reset(1); + norm_constant_device.copy_from_host(&options.norm_constant); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + if (!is_static_v) { + if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 && + (options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) { + std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl; + } + hw_info.cluster_shape = options.cluster_shape; + hw_info.cluster_shape_fallback = options.cluster_shape_fallback; + } + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
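As background for the two branches below: `alpha_ptr_array` takes an array of per-group *pointers* rather than a flat array of scalars, which is why `initialize()` above copies the host-side alpha values into one contiguous device block and then uploads a second array whose i-th entry points at element i of that block. The following standalone sketch shows that indirection in isolation; it reuses `cutlass::DeviceAllocation` exactly as this example does, but the function and variable names are illustrative and not part of the example:

```cpp
#include <vector>
#include "cutlass/util/device_memory.h"   // cutlass::DeviceAllocation

// Illustrative sketch: expose one alpha per group to the epilogue as an array of pointers.
void build_per_group_alpha(int groups) {
  std::vector<float> alphas(groups, 1.0f);                  // per-group scalars on the host

  cutlass::DeviceAllocation<float> scalars_device;          // contiguous device storage
  scalars_device.reset(groups);
  scalars_device.copy_from_host(alphas.data());

  std::vector<float*> ptrs_host(groups);
  for (int i = 0; i < groups; ++i) {
    ptrs_host[i] = scalars_device.get() + i;                // i-th pointer -> i-th scalar
  }

  cutlass::DeviceAllocation<float*> ptrs_device;            // device-side array of pointers
  ptrs_device.reset(groups);
  ptrs_device.copy_from_host(ptrs_host.data());

  // The epilogue then reads one alpha per group via
  //   fusion_args.alpha_ptr_array = ptrs_device.get();
  //   fusion_args.dAlpha          = {_0{}, _0{}, 1};       // group-stride of 1
}
```

The trailing entry of `dAlpha`/`dBeta` set in the branches below can then be read as the group stride: 0 broadcasts the single scalar to every group, 1 advances through the pointer array one element per group.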
+ if (options.alpha != FLT_MAX){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + if (options.beta != FLT_MAX) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + // Output Block SF + // fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output + // fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB); + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), 
layout_C); + auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); + + cutlass::reference::host::GettEpilogueParams< + float, float, + ElementAccumulator, ElementAccumulator, + decltype(tensor_C), decltype(tensor_ref_D) + > epilogue_params{}; + + epilogue_params.C = tensor_C; + epilogue_params.D = tensor_ref_D; + epilogue_params.alpha = alpha_host.at(i); + epilogue_params.beta = beta_host.at(i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + block_D.at(i).sync_host(); + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view()); + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + if (options.verification) { + std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + } + else { + std::cout << " Verfication is turned off for this run." << std::endl; + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. 
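For a sense of scale for the GFLOP/s figure computed just below (via `Options::gflops`, which sums 2*M*N*K over all groups), here is a worked instance using the example command from the header of this file, 10 groups of 2048x2048x2048, with a hypothetical 1 ms average runtime:

```cpp
#include <cstdio>

int main() {
  // Worked instance of the GFLOP/s formula used by Options::gflops above:
  // 10 groups, each 2048 x 2048 x 2048 (the command shown at the top of this file).
  double fmas  = 10.0 * 2048.0 * 2048.0 * 2048.0;   // multiply-adds across all groups
  double gflop = 2.0 * fmas / 1.0e9;                // two flops per multiply-add, ~171.8 GFLOP total
  double avg_runtime_ms = 1.0;                      // hypothetical measured average
  std::printf("GFLOP/s = %.1f\n", gflop / (avg_runtime_ms / 1000.0));  // ~171,800 GFLOP/s (~171.8 TFLOP/s)
  return 0;
}
```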
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || + ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) + ) + ) { + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 10 && props.minor == 0)) { + std::cerr + << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "Running kernel with 1SM MMA config:" << std::endl; + run(options, false /*host_problem_shapes_available*/); + std::cout << "Running kernel with 2SM MMA config:" << std::endl; + run(options, false /*host_problem_shapes_available*/); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/75_blackwell_grouped_gemm/CMakeLists.txt b/examples/75_blackwell_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000..2da2d4c43b --- /dev/null +++ b/examples/75_blackwell_grouped_gemm/CMakeLists.txt @@ -0,0 +1,88 @@ +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + + + +set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=51 --iterations=0) # Fixed problem sizes + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) # Small problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes +set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 75_blackwell_grouped_gemm + 75_blackwell_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + ) + +cutlass_example_add_executable( + 75_blackwell_grouped_gemm_block_scaled + 75_blackwell_grouped_gemm_block_scaled.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + ) +endif() diff --git a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu new file mode 100644 index 0000000000..daadcd56a0 --- /dev/null +++ b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu @@ -0,0 +1,534 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief Simple dgrad convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
+
+    This example demonstrates a simple way to instantiate and run a dgrad convolution kernel using the new CUTLASS 3.0
+    APIs on the NVIDIA Blackwell SM100 architecture.
+
+    The basic computation logic of a dgrad convolution kernel, taking 3D convolution as an example, is:
+      Xformed Activation (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
+
+    where, from a GEMM perspective,
+      Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation
+
+    This example instantiates a simple dgrad kernel using a TMA + UMMA + Warp Specialized design with fp16 input and output types.
+    Alpha/beta scaling is supported, while fusions such as relu/bias/per-channel scaling are not supported in this example.
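As a concrete instance of the mapping above (an illustrative standalone sketch, not part of this example): using this example's default problem size and the same output-extent formula that Options::parse below uses, the implied implicit-GEMM extents work out as follows.

```cpp
#include <cstdio>

int main() {
  // Default dgrad problem of this example: NDHWC = 4x1x8x8x64, KTRSC = 64x1x3x3x64,
  // padding (0,1,1), stride (1,1,1), dilation (1,1,1).
  int n = 4, d = 1, h = 8, w = 8, c = 64, k = 64, t = 1, r = 3, s = 3;
  int pad_d = 0, pad_h = 1, pad_w = 1, stride = 1, dilation = 1;

  // Xformed-activation extents (same formula as Options::parse in this file).
  int z = 1 + (d + 2 * pad_d - ((t - 1) * dilation + 1)) / stride;  // = 1
  int p = 1 + (h + 2 * pad_h - ((r - 1) * dilation + 1)) / stride;  // = 8
  int q = 1 + (w + 2 * pad_w - ((s - 1) * dilation + 1)) / stride;  // = 8

  // Implicit-GEMM view of dgrad: A = xformed activation, B = filter, C = activation.
  long gemm_m = long(n) * d * h * w;  // 256  (activation pixels being produced)
  long gemm_n = c;                    // 64   (input channels)
  long gemm_k = long(k) * t * r * s;  // 576  (output channels times filter taps)
  std::printf("GEMM MxNxK = %ldx%ldx%ld (z,p,q = %d,%d,%d)\n",
              gemm_m, gemm_n, gemm_k, z, p, q);
  return 0;
}
```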
+ + Usage: + + $ ./examples/76_blackwell_conv/76_blackwell_conv_dgrad --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0 + --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 +*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Activation matrix configuration +using ElementAct = half_t; // Element type for activation matrix +constexpr int AlignmentAct = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes) + +// Weight/Filter matrix configuration +using ElementFlt = half_t; // Element type for weight/filter matrix operand +constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes) + +// Xformed activation matrix configuration +using ElementXformedAct = half_t; // Element type for xformed activation matrix operand +constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes) + +// Layout of matrix A/B/C in gemm's perspecitive. 
+using LayoutA = cutlass::layout::TensorNDHWC; +using LayoutB = cutlass::layout::TensorNDHWC; +using LayoutC = cutlass::layout::TensorNDHWC; + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal computation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kDgrad; // Convolution operation + +// Kernel Perf config +using TileShape = Shape<_128,_128,Shape<_64>>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementAct, LayoutC, AlignmentAct, + ElementAct, LayoutC, AlignmentAct, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + ArchTag, OperatorClass, ConvOp, + ElementXformedAct, LayoutA, AlignmentXformedAct, + ElementFlt, LayoutB, AlignmentFlt, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + +// Compose into a kernel +using ProblemShape=cutlass::conv::ConvProblemShape; +using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + +using Conv = cutlass::conv::device::ConvUniversalAdapter; + +using StrideC = typename Conv::ConvKernel::StrideC; +using StrideD = typename Conv::ConvKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int n, d, h, w, c, k, t, r, s, z, p, q; + int pad_d, pad_h, pad_w; + int stride_d, stride_h, stride_w; + int dilation_d, dilation_h, dilation_w; + + Options(): + help(false), + n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3), + pad_d(0), pad_h(1), pad_w(1), + stride_d(1), stride_h(1), stride_w(1), + dilation_d(1), dilation_h(1), dilation_w(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("d", d); + cmd.get_cmd_line_argument("h", h); + cmd.get_cmd_line_argument("w", w); + cmd.get_cmd_line_argument("c", c); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("t", t); + 
cmd.get_cmd_line_argument("r", r); + cmd.get_cmd_line_argument("s", s); + cmd.get_cmd_line_argument("pad_d", pad_d); + cmd.get_cmd_line_argument("pad_h", pad_h); + cmd.get_cmd_line_argument("pad_w", pad_w); + cmd.get_cmd_line_argument("stride_d", stride_d); + cmd.get_cmd_line_argument("stride_h", stride_h); + cmd.get_cmd_line_argument("stride_w", stride_w); + cmd.get_cmd_line_argument("dilation_d", dilation_d); + cmd.get_cmd_line_argument("dilation_h", dilation_h); + cmd.get_cmd_line_argument("dilation_w", dilation_w); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + // Calculate z,p,q based on inputs. + z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d; + p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h; + q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "76_blackwell_conv_dgrad\n\n" + << " Blackwell FP16 dgrad convolution using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --n= Sets the batch size of the Activation\n" + << " --d= Sets the depth size of the Activation\n" + << " --h= Sets the height of the Activation\n" + << " --w= Sets the width of the Activation\n" + << " --c= Sets the channel size of the Activation\n" + << " --k= Sets the image numbers of the Filter\n" + << " --t= Sets the depth size of the Filter\n" + << " --r= Sets the height of the Filter\n" + << " --s= Sets the width of the Filter\n" + << " --pad_d= Sets the padding size in depth\n" + << " --pad_h= Sets the padding size in height\n" + << " --pad_w= Sets the padding size in width\n" + << " --stride_d= Sets the traversal stride size in depth\n" + << " --stride_h= Sets the traversal stride size in height\n" + << " --stride_w= Sets the traversal stride size in width\n" + << " --dialtion_d= Sets the filter dilation size in depth\n" + << " --dialtion_h= Sets the filter dilation size in height\n" + << " --dialtion_w= Sets the filter dilation size in width\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "76_blackwell_conv_dgrad" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0" + << " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * (n * d * h * w) * c * (t * r * s * k); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv setup and evaluation 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the Conv and reference Conv +void initialize(const Options &options) { + + // Construct ConvProblemShape + ProblemShape problem_shape( + cutlass::conv::Mode::kCrossCorrelation, + {options.n, options.d, options.h, options.w, options.c}, // ndhwc + {options.k, options.t, options.r, options.s, options.c}, // ktrsc + {options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w) + {options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w) + {options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w) + {options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + ); + + // Setup stride_C/D + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { + cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { + cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + + block_A.reset(problem_shape.size_A()); + block_B.reset(problem_shape.size_B()); + block_C.reset(problem_shape.size_C()); + block_D.reset(problem_shape.size_C()); + block_ref_D.reset(problem_shape.size_C()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Conv::Arguments args_from_options(const Options &options) +{ + // Construct ConvProblemShape + ProblemShape problem_shape( + cutlass::conv::Mode::kCrossCorrelation, + {options.n, options.d, options.h, options.w, options.c}, // ndhwc + {options.k, options.t, options.r, options.s, options.c}, // ktrsc + {options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w) + {options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w) + {options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w) + {options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + ); + + typename Conv::Arguments arguments{ + problem_shape, + {block_A.get(), block_B.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.z, options.p, options.q, options.k})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.t, options.r, options.s, options.c})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.n, options.d, options.h, options.w, options.c})); + cutlass::TensorRef ref_D(block_ref_D.get(), 
LayoutC::packed({options.n, options.d, options.h, options.w, options.c})); + + // + // Compute reference output + // + + // Construct Conv3dProblemSize with user defined inputs. + cutlass::conv::Conv3dProblemSize problem_size( + cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc + cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc + cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding + cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w) + cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w) + cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk + ); + + // Launch device reference conv kernel + cutlass::reference::device::Conv3dDgrad(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Conv conv; + + // Create a structure of conv kernel arguments suitable for invoking an instance of Conv + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Conv::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(conv.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(conv.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(conv.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(conv.initialize(arguments, workspace.get())); + CUTLASS_CHECK(conv.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
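As a quick sanity check of the numbers reported below, the dgrad FLOP model from Options::gflops() can be evaluated at the default command-line sizes listed above (n=4, d=1, h=8, w=8, c=64, k=64, t=1, r=3, s=3); this is a worked example only and holds just for those defaults:

    // flop = 2 * (n*d*h*w) * c * (t*r*s*k)
    //      = 2 * 256 * 64 * 576
    //      = 18,874,368 FLOP  (~0.0189 GFLOP per dgrad pass)
    // so an average runtime of 1 ms would be reported as roughly 18.9 GFLOPS.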
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size:" << std::endl;
+    std::cout << "    Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
+    std::cout << "    Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
+    std::cout << "    Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.8 Toolkit to run this example
+  // and the device must have compute capability 100 or 101.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || (props.minor != 0 && props.minor != 1)) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu
new file mode 100644
index 0000000000..8598637e15
--- /dev/null
+++ b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu
@@ -0,0 +1,534 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple fprop convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. + + This example demonstrate a simple way to instantiate and run a fprop convolution kernel using the new CUTLASS 3.0 + APIs on NVIDIA Blackwell SM100 architecture. + + The basic computation logic of fprop convolution kernel is, take 3D convolution as an example: + Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK) + + where in terms of GEMM perspective, + Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation + + This example instantiates a simple fprop kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16. + Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example. + + Usage: + + $ ./examples/76_blackwell_conv/76_blackwell_conv_fprop --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0 + --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 +*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Activation matrix configuration +using ElementAct = half_t; // Element type for activation matrix +constexpr int AlignmentAct = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes) + +// 
Weight/Filter matrix configuration +using ElementFlt = half_t; // Element type for weight/filter matrix operand +constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes) + +// Xformed activation matrix configuration +using ElementXformedAct = half_t; // Element type for xformed activation matrix operand +constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes) + +// Layout of matrix A/B/C in gemm's perspecitive. +using LayoutA = cutlass::layout::TensorNDHWC; +using LayoutB = cutlass::layout::TensorNDHWC; +using LayoutC = cutlass::layout::TensorNDHWC; + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal computation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kFprop; // Convolution operation + +// Kernel Perf config +using TileShape = Shape<_128,_128,Shape<_64>>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementXformedAct, LayoutC, AlignmentXformedAct, + ElementXformedAct, LayoutC, AlignmentXformedAct, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + ArchTag, OperatorClass, ConvOp, + ElementAct, LayoutA, AlignmentAct, + ElementFlt, LayoutB, AlignmentFlt, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + +// Compose into a kernel +using ProblemShape=cutlass::conv::ConvProblemShape; +using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + +using Conv = cutlass::conv::device::ConvUniversalAdapter; + +using StrideC = typename Conv::ConvKernel::StrideC; +using StrideD = typename Conv::ConvKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int n, d, h, w, c, k, t, r, s, z, p, q; + int pad_d, pad_h, pad_w; + int stride_d, stride_h, stride_w; + int dilation_d, dilation_h, dilation_w; + + Options(): + help(false), + n(4), d(1), h(8), 
+    w(8), c(64), k(64), t(1), r(3), s(3),
+    pad_d(0), pad_h(1), pad_w(1),
+    stride_d(1), stride_h(1), stride_w(1),
+    dilation_d(1), dilation_h(1), dilation_w(1),
+    alpha(1.f), beta(0.f),
+    iterations(10)
+  { }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("n", n);
+    cmd.get_cmd_line_argument("d", d);
+    cmd.get_cmd_line_argument("h", h);
+    cmd.get_cmd_line_argument("w", w);
+    cmd.get_cmd_line_argument("c", c);
+    cmd.get_cmd_line_argument("k", k);
+    cmd.get_cmd_line_argument("t", t);
+    cmd.get_cmd_line_argument("r", r);
+    cmd.get_cmd_line_argument("s", s);
+    cmd.get_cmd_line_argument("pad_d", pad_d);
+    cmd.get_cmd_line_argument("pad_h", pad_h);
+    cmd.get_cmd_line_argument("pad_w", pad_w);
+    cmd.get_cmd_line_argument("stride_d", stride_d);
+    cmd.get_cmd_line_argument("stride_h", stride_h);
+    cmd.get_cmd_line_argument("stride_w", stride_w);
+    cmd.get_cmd_line_argument("dilation_d", dilation_d);
+    cmd.get_cmd_line_argument("dilation_h", dilation_h);
+    cmd.get_cmd_line_argument("dilation_w", dilation_w);
+    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("iterations", iterations);
+
+    // Calculate z,p,q based on inputs.
+    z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
+    p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
+    q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
+  }
+
+  /// Prints the usage statement.
+  std::ostream & print_usage(std::ostream &out) const {
+
+    out << "76_blackwell_conv_fprop\n\n"
+      << "  Blackwell FP16 fprop convolution using a Warp Specialized kernel.\n\n"
+      << "Options:\n\n"
+      << "  --help            If specified, displays this usage statement\n\n"
+      << "  --n=              Sets the batch size of the Activation\n"
+      << "  --d=              Sets the depth size of the Activation\n"
+      << "  --h=              Sets the height of the Activation\n"
+      << "  --w=              Sets the width of the Activation\n"
+      << "  --c=              Sets the channel size of the Activation\n"
+      << "  --k=              Sets the number of images of the Filter\n"
+      << "  --t=              Sets the depth size of the Filter\n"
+      << "  --r=              Sets the height of the Filter\n"
+      << "  --s=              Sets the width of the Filter\n"
+      << "  --pad_d=          Sets the padding size in depth\n"
+      << "  --pad_h=          Sets the padding size in height\n"
+      << "  --pad_w=          Sets the padding size in width\n"
+      << "  --stride_d=       Sets the traversal stride size in depth\n"
+      << "  --stride_h=       Sets the traversal stride size in height\n"
+      << "  --stride_w=       Sets the traversal stride size in width\n"
+      << "  --dilation_d=     Sets the filter dilation size in depth\n"
+      << "  --dilation_h=     Sets the filter dilation size in height\n"
+      << "  --dilation_w=     Sets the filter dilation size in width\n"
+      << "  --alpha=          Epilogue scalar alpha\n"
+      << "  --beta=           Epilogue scalar beta\n\n"
+      << "  --iterations=     Number of profiling iterations to perform.\n\n";
+
+    out
+      << "\n\nExamples:\n\n"
+      << "$ " << "76_blackwell_conv_fprop" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
+      << " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
+
+    return out;
+  }
+
+  /// Compute performance in GFLOP/s
+  double gflops(double runtime_s) const
+  {
+    // Two flops per multiply-add
+    uint64_t flop = uint64_t(2) * (n * z * p * q) * k * (t * r * s * c);
+    double gflop = double(flop) / double(1.0e9);
+    return gflop / runtime_s;
+  }
+};
+
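For reference, the z/p/q values computed in Options::parse() follow the standard cross-correlation output-extent formula. Below is a minimal standalone sketch; the helper name conv_output_extent is illustrative only and is not part of this example:

// Output extent of one spatial dimension of a cross-correlation with the given
// padding, stride, and dilation; mirrors the z/p/q computation in Options::parse().
inline int conv_output_extent(int input, int pad, int filter, int stride, int dilation) {
  int effective_filter = (filter - 1) * dilation + 1;  // dilated filter footprint
  return 1 + (input + 2 * pad - effective_filter) / stride;
}

// With the defaults above: conv_output_extent(8, 1, 3, 1, 1) == 8, so p == h and q == w,
// and conv_output_extent(1, 0, 1, 1, 1) == 1, so z == d.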
+/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the Conv and reference Conv +void initialize(const Options &options) { + + // Construct ConvProblemShape + ProblemShape problem_shape( + cutlass::conv::Mode::kCrossCorrelation, + {options.n, options.d, options.h, options.w, options.c}, // ndhwc + {options.k, options.t, options.r, options.s, options.c}, // ktrsc + {options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w) + {options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w) + {options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w) + {options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + ); + + // Setup stride_C/D + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { + cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { + cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + + block_A.reset(problem_shape.size_A()); + block_B.reset(problem_shape.size_B()); + block_C.reset(problem_shape.size_C()); + block_D.reset(problem_shape.size_C()); + block_ref_D.reset(problem_shape.size_C()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Conv::Arguments args_from_options(const Options &options) +{ + // Construct ConvProblemShape + ProblemShape problem_shape( + cutlass::conv::Mode::kCrossCorrelation, + {options.n, options.d, options.h, options.w, options.c}, // ndhwc + {options.k, options.t, options.r, options.s, options.c}, // ktrsc + {options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w) + {options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w) + {options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w) + {options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + ); + + typename Conv::Arguments arguments{ + problem_shape, + 
{block_A.get(), block_B.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.d, options.h, options.w, options.c})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.t, options.r, options.s, options.c})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.n, options.z, options.p, options.q, options.k})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutC::packed({options.n, options.z, options.p, options.q, options.k})); + + // + // Compute reference output + // + + // Construct Conv3dProblemSize with user defined inputs. + cutlass::conv::Conv3dProblemSize problem_size( + cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc + cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc + cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding + cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w) + cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w) + cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk + ); + + // Launch device reference conv kernel + cutlass::reference::device::Conv3dFprop(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Conv conv; + + // Create a structure of conv kernel arguments suitable for invoking an instance of Conv + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Conv::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(conv.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(conv.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(conv.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(conv.initialize(arguments, workspace.get())); + CUTLASS_CHECK(conv.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size:" << std::endl;
+    std::cout << "    Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
+    std::cout << "    Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
+    std::cout << "    Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.8 Toolkit to run this example
+  // and the device must have compute capability 100 or 101.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || (props.minor != 0 && props.minor != 1)) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu
new file mode 100644
index 0000000000..d99cdacc97
--- /dev/null
+++ b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu
@@ -0,0 +1,530 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple wgrad convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. + + This example demonstrate a simple way to instantiate and run a wgrad convolution kernel using the new CUTLASS 3.0 + APIs on NVIDIA Blackwell SM100 architecture. + + The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example: + Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC) + + where in terms of GEMM perspective, + Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter + + This example instantiates a simple wgrad kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16. + Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example. + + Usage: + + $ ./examples/76_blackwell_conv/76_blackwell_conv_wgrad --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0 + --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 +*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Activation matrix configuration +using ElementAct = half_t; // Element type for activation matrix +constexpr int AlignmentAct = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes) + +// 
Weight/Filter matrix configuration +using ElementFlt = half_t; // Element type for weight/filter matrix operand +constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes) + +// Xformed activation matrix configuration +using ElementXformedAct = half_t; // Element type for xformed activation matrix operand +constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes) + +// Layout of matrix A/B/C in gemm's perspecitive. +using LayoutA = cutlass::layout::TensorNDHWC; +using LayoutB = cutlass::layout::TensorNDHWC; +using LayoutC = cutlass::layout::TensorKCSRT; + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal computation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kWgrad; // Convolution operation + +// Kernel Perf config +using TileShape = Shape<_128,Shape<_128>,Shape<_64>>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementFlt, LayoutC, AlignmentFlt, + ElementFlt, LayoutC, AlignmentFlt, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + ArchTag, OperatorClass, ConvOp, + ElementXformedAct, LayoutA, AlignmentXformedAct, + ElementAct, LayoutB, AlignmentAct, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + +// Compose into a kernel +using ProblemShape=cutlass::conv::ConvProblemShape; +using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + +using Conv = cutlass::conv::device::ConvUniversalAdapter; + +using StrideC = typename Conv::ConvKernel::StrideC; +using StrideD = typename Conv::ConvKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int n, d, h, w, c, k, t, r, s, z, p, q; + int pad_d, pad_h, pad_w; + int stride_d, stride_h, stride_w; + int dilation_d, dilation_h, dilation_w; + + Options(): + help(false), + n(4), d(1), h(8), w(8), 
+    c(64), k(64), t(1), r(3), s(3),
+    pad_d(0), pad_h(1), pad_w(1),
+    stride_d(1), stride_h(1), stride_w(1),
+    dilation_d(1), dilation_h(1), dilation_w(1),
+    alpha(1.f), beta(0.f),
+    iterations(10)
+  { }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("n", n);
+    cmd.get_cmd_line_argument("d", d);
+    cmd.get_cmd_line_argument("h", h);
+    cmd.get_cmd_line_argument("w", w);
+    cmd.get_cmd_line_argument("c", c);
+    cmd.get_cmd_line_argument("k", k);
+    cmd.get_cmd_line_argument("t", t);
+    cmd.get_cmd_line_argument("r", r);
+    cmd.get_cmd_line_argument("s", s);
+    cmd.get_cmd_line_argument("pad_d", pad_d);
+    cmd.get_cmd_line_argument("pad_h", pad_h);
+    cmd.get_cmd_line_argument("pad_w", pad_w);
+    cmd.get_cmd_line_argument("stride_d", stride_d);
+    cmd.get_cmd_line_argument("stride_h", stride_h);
+    cmd.get_cmd_line_argument("stride_w", stride_w);
+    cmd.get_cmd_line_argument("dilation_d", dilation_d);
+    cmd.get_cmd_line_argument("dilation_h", dilation_h);
+    cmd.get_cmd_line_argument("dilation_w", dilation_w);
+    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("iterations", iterations);
+
+    // Calculate z,p,q based on inputs.
+    z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
+    p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
+    q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
+  }
+
+  /// Prints the usage statement.
+  std::ostream & print_usage(std::ostream &out) const {
+
+    out << "76_blackwell_conv_wgrad\n\n"
+      << "  Blackwell FP16 wgrad convolution using a Warp Specialized kernel.\n\n"
+      << "Options:\n\n"
+      << "  --help            If specified, displays this usage statement\n\n"
+      << "  --n=              Sets the batch size of the Activation\n"
+      << "  --d=              Sets the depth size of the Activation\n"
+      << "  --h=              Sets the height of the Activation\n"
+      << "  --w=              Sets the width of the Activation\n"
+      << "  --c=              Sets the channel size of the Activation\n"
+      << "  --k=              Sets the number of images of the Filter\n"
+      << "  --t=              Sets the depth size of the Filter\n"
+      << "  --r=              Sets the height of the Filter\n"
+      << "  --s=              Sets the width of the Filter\n"
+      << "  --pad_d=          Sets the padding size in depth\n"
+      << "  --pad_h=          Sets the padding size in height\n"
+      << "  --pad_w=          Sets the padding size in width\n"
+      << "  --stride_d=       Sets the traversal stride size in depth\n"
+      << "  --stride_h=       Sets the traversal stride size in height\n"
+      << "  --stride_w=       Sets the traversal stride size in width\n"
+      << "  --dilation_d=     Sets the filter dilation size in depth\n"
+      << "  --dilation_h=     Sets the filter dilation size in height\n"
+      << "  --dilation_w=     Sets the filter dilation size in width\n"
+      << "  --alpha=          Epilogue scalar alpha\n"
+      << "  --beta=           Epilogue scalar beta\n\n"
+      << "  --iterations=     Number of profiling iterations to perform.\n\n";
+
+    out
+      << "\n\nExamples:\n\n"
+      << "$ " << "76_blackwell_conv_wgrad" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
+      << " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
+
+    return out;
+  }
+
+  /// Compute performance in GFLOP/s
+  double gflops(double runtime_s) const
+  {
+    // Two flops per multiply-add
+    uint64_t flop = uint64_t(2) * k * (t * r * s * c) * (n * z * p * q);
+    double gflop = double(flop) / double(1.0e9);
+    return gflop / runtime_s;
+  }
+};
+
+/// 
Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the Conv and reference Conv +void initialize(const Options &options) { + + // Construct ConvProblemShape + ProblemShape problem_shape( + cutlass::conv::Mode::kCrossCorrelation, + {options.n, options.d, options.h, options.w, options.c}, // ndhwc + {options.k, options.t, options.r, options.s, options.c}, // ktrsc + {options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w) + {options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w) + {options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w) + {options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + ); + + // Setup stride_C/D + stride_C = cutlass::make_cute_packed_stride(StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + + block_A.reset(problem_shape.size_A()); + block_B.reset(problem_shape.size_B()); + block_C.reset(problem_shape.size_C()); + block_D.reset(problem_shape.size_C()); + block_ref_D.reset(problem_shape.size_C()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Conv::Arguments args_from_options(const Options &options) +{ + // Construct ConvProblemShape + ProblemShape problem_shape( + cutlass::conv::Mode::kCrossCorrelation, + {options.n, options.d, options.h, options.w, options.c}, // ndhwc + {options.k, options.t, options.r, options.s, options.c}, // ktrsc + {options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w) + {options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w) + {options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w) + {options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + ); + + typename Conv::Arguments arguments{ + problem_shape, + {block_A.get(), block_B.get()}, + {{options.alpha, options.beta}, 
block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.z, options.p, options.q, options.k})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.n, options.d, options.h, options.w, options.c})); + cutlass::TensorRef ref_C(block_C.get(), LayoutA::packed({options.k, options.t, options.r, options.s, options.c})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutA::packed({options.k, options.t, options.r, options.s, options.c})); + + // + // Compute reference output + // + + // Construct Conv3dProblemSize with user defined inputs. + cutlass::conv::Conv3dProblemSize problem_size( + cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc + cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc + cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding + cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w) + cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w) + cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk + ); + + // Launch device reference conv kernel + cutlass::reference::device::Conv3dWgrad(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Conv conv; + + // Create a structure of conv kernel arguments suitable for invoking an instance of Conv + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Conv::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(conv.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(conv.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(conv.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(conv.initialize(arguments, workspace.get())); + CUTLASS_CHECK(conv.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
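The same kind of spot check applies to the wgrad FLOP model, again assuming only the default sizes above (n=4, d=1, h=8, w=8, c=64, k=64, t=1, r=3, s=3, hence z=1 and p=q=8):

    // flop = 2 * k * (t*r*s*c) * (n*z*p*q)
    //      = 2 * 64 * 576 * 256
    //      = 18,874,368 FLOP, i.e. the same multiply-add count as fprop/dgrad for this shape.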
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size:" << std::endl;
+    std::cout << "    Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
+    std::cout << "    Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
+    std::cout << "    Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
+    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.8 Toolkit to run this example
+  // and the device must have compute capability 100 or 101.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || (props.minor != 0 && props.minor != 1)) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/76_blackwell_conv/CMakeLists.txt b/examples/76_blackwell_conv/CMakeLists.txt
new file mode 100644
index 0000000000..8d31d7433f
--- /dev/null
+++ b/examples/76_blackwell_conv/CMakeLists.txt
@@ -0,0 +1,46 @@
+# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 76_blackwell_conv_fprop + 76_blackwell_conv_fprop.cu +) + +cutlass_example_add_executable( + 76_blackwell_conv_dgrad + 76_blackwell_conv_dgrad.cu +) + +cutlass_example_add_executable( + 76_blackwell_conv_wgrad + 76_blackwell_conv_wgrad.cu +) +endif() diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu new file mode 100644 index 0000000000..1d1314d145 --- /dev/null +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -0,0 +1,990 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example implementation of fused multi-head attention for the NVIDIA Blackwell SM100 + architecture using CUTLASS 3. + + MQA/GQA + ------- + + The head dimension can be represented as a tuple, where the K/V strides in the + first dimension is zero. This has the effect of MQA or GQA. 
+ * MHA is (head_size:head_stride). + * MQA is (head_size:head_stride) in Q and (head_size:_0) in K and V. + * GQA is (grouped_heads,heads_kv):(head_stride,grouped_heads*head_stride) in Q + and (grouped_heads,heads_kv):(0,head_stride) in K and V + + Output Scale + ------------ + + The output scale gets passed to the collective mainloop, and is applied + using FP32 compute pre-quantization + + Variable Sequence Length + ------------------------ + + For variable sequence length, pass in VariableLength objects + (max_seqlen, cumulative_seqlen_ptr) in the problem shape for + seqlen Q and KV. + + Support + --------- + + Right now e4m3 with fp32 compute is using a 256x256 tiling and a head dimension + of 128 is supported. + + + Example usage: + $ ./examples/77_blackell_fmha/77_blackell_fmha_fp8 \ + --b=2048 --h=2048 --d=2048 --q=2048 --k=2048 +*/ + +#define DSHOW(x) print(#x ": "); print(x); print("\n"); +#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n"); + +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "reference/fmha_fwd_reference.hpp" +#include "reference/reference_abs_error.hpp" + +#include "device/fmha.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp" +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class InitStyle { + kOne, kLinearStride128, kLinearStride1, kRandom, kNone +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help = false; + bool error = false; + + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int d = 128; + int iterations = 3; + bool verify = false; + bool verbose = false; + + bool causal = false; + bool residual = false; + bool varlen = false; + int sm_count = 0; + + std::string kernel_filter; + + InitStyle init_style_q = InitStyle::kRandom; + InitStyle init_style_k = InitStyle::kRandom; + InitStyle init_style_v = InitStyle::kRandom; + + static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) { + std::string s; + cmd.get_cmd_line_argument(name, s, s); + if (s.empty()) { + dst = src; + } + else { + if (s == "r") { + dst = InitStyle::kRandom; + } + else if (s == "1") { + dst = InitStyle::kOne; + } + else if (s == "d") { + dst = InitStyle::kLinearStride1; + } + else if (s == "s") { + dst = InitStyle::kLinearStride128; + } + else if (s == "n") { + dst = InitStyle::kNone; + } + else { + std::cout << "Error: " << s << " is not a valid input type.\n"; + std::exit(-1); + } + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + 
help = true; + return; + } + + cmd.get_cmd_line_argument("d", d, defaults.d); + cmd.get_cmd_line_argument("h", h, -1); + if (h == -1) h = 2048 / d; + + cmd.get_cmd_line_argument("h_k", h_k, -1); + if (h_k == -1) h_k = h; + + cmd.get_cmd_line_argument("q", q, -1); + cmd.get_cmd_line_argument("k", k, -1); + if (q == -1) q = k; + if (k == -1) k = q; + if (q == -1 && k == -1) q = k = defaults.q; + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + varlen = cmd.check_cmd_line_flag("varlen"); + std::string mask; + cmd.get_cmd_line_argument("mask", mask, ""); + if (mask == "no" || mask == "") { + causal = residual = false; + if (varlen) { + residual = true; + } + } + else if (mask == "causal") { + residual = false; + causal = true; + } + else if (mask == "residual") { + residual = true; + causal = false; + } + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); + + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q); + get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q); + get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k); + get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v); + + cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "77_blackwell_fmha\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " fused multi-head attention forward-passkernels targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --h= Sets the H extent\n" + << " --h_k= Sets the H_K/V extent (for GQA/MQA)\n" + << " --q= Sets the Q extent\n" + << " --k= Sets the K extent\n" + << " --d= Sets the D extentn" + << " --iterations= Benchmarking iterations\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --mask= Enables masking\n" + << " --varlen Enables variable sequence length\n" + << " B*Q and B*K become the total sequence length\n" + << " and are split B-ways, alternatingly +10% and -10%\n" + << " with the last batch sized to make it fit\n" + << " implies at least residual masking for correctness\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" + << "\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_block( + DeviceAllocation& block, + uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) { + + switch (init_style) { + case InitStyle::kOne: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + break; + } + case InitStyle::kRandom: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) 0, (Element) 1); + break; + } + case 
InitStyle::kLinearStride1: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (j % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kLinearStride128: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (i % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kNone: { + break; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_tc_s = 0; + double tops_exp2_s = 0; + double tbytes_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + bool kIsVarlen, + class TileShape, + class DispatchPolicy, + class ActiveMask, + class... KernelOptions +> +struct FwdRunner { + +#ifdef FP8 + using Element = cutlass::float_e4m3_t; +#else + using Element = cutlass::half_t; +#endif + + using ElementAccumulatorQK = float; + using ElementAccumulatorPV = float; + using ElementOut = cutlass::half_t; + + // Q K D (B H) + using ProblemShapeRegular = cute::tuple, int>>; + using ProblemShapeVarlen = cute::tuple, int>>; + using ProblemShapeType = std::conditional_t; + + using StrideQ = cute::tuple, int>>; // Q D (H_G H_R B) + using StrideK = cute::tuple, int>>; // K D (H_G H_R B) + using StrideV = StrideK; + using StrideO = StrideQ; + using StrideLSE = cute::tuple<_1, cute::tuple, int>>; // Q (H_G H_R B) + + static constexpr bool kIsPersistent = find_option_t::value; + using TileScheduler = std::conditional_t; + + using Mainloop = + cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, StrideQ, StrideK, StrideV, + ActiveMask + >; + using Operation = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized< + ProblemShapeType, + Mainloop, + cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized< + ElementOut, ElementAccumulatorPV, + typename Mainloop::TileShapePV, + StrideO, StrideLSE + >, + TileScheduler + >>; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + uint64_t seed = 0; + + DeviceAllocation block_Q; + DeviceAllocation block_K; + DeviceAllocation block_V; + DeviceAllocation block_O; + DeviceAllocation block_LSE; + DeviceAllocation block_ref_O; + DeviceAllocation block_ref_LSE; + + std::vector cumulative_seqlen_q; + std::vector cumulative_seqlen_kv; + DeviceAllocation device_cumulative_seqlen_q; + DeviceAllocation device_cumulative_seqlen_kv; + + // + // Methods + // + bool verify(const ProblemShapeType& problem_shape) { + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + select<0,2,3>(problem_shape), + stride_Q); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + select<1,2,3>(problem_shape), + stride_K); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + select<1,2,3>(problem_shape), + stride_V); + + Tensor mO = 
make_tensor(make_gmem_ptr(block_ref_O.get()), + select<0,2,3>(problem_shape), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), + select<0,3>(problem_shape), + stride_LSE); + + fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{}); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; + reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); + + bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_O) { + std::cerr << "failed O: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + // reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); + + bool passed_LSE = true; // future work + // bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + // if ( ! passed_LSE) { + // std::cerr << "failed LSE: max diff " << max_diff + // << " mean " << mean_diff << std::endl; + // } + + return passed_O && passed_LSE; + } + + template + auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) { + int num_batches = get<3,1>(problem_size); + + // generate Q as --b times + // gaussian (--Q, --Q / 2) sampled positive + // track cumulative + std::mt19937 rng(0x202305151552ull); + std::normal_distribution dist_q(get<0>(problem_size), get<0>(problem_size) / 2); + std::normal_distribution dist_kv(get<1>(problem_size), get<1>(problem_size) / 2); + std::cout << "N: " << num_batches << ", Q: " << get<0>(problem_size) << ", KV: " << get<1>(problem_size) << std::endl; + + auto generate_positive_int = [](auto& dist, auto& gen) { + int result = 0; + do { + result = static_cast(dist(gen)); + } while (result <= 0); + return result; + }; + + cumulative_seqlen_q = {0}; + cumulative_seqlen_kv = {0}; + + int total_seqlen_q = 0; + int total_seqlen_kv = 0; + int max_seqlen_q = 0; + int max_seqlen_kv = 0; + + for (int i = 0; i < num_batches; i++) { + int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng); + int seqlen_kv = kVarlenSame ? 
get<1>(problem_size) : generate_positive_int(dist_kv, rng); + + total_seqlen_q += seqlen_q; + total_seqlen_kv += seqlen_kv; + + max_seqlen_q = std::max(max_seqlen_q, seqlen_q); + max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); + + cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); + cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); + } + std::cout << "Q max: " << max_seqlen_q << " total: " << total_seqlen_q << " vs even " << num_batches * get<0>(problem_size) << std::endl; + std::cout << "KV max: " << max_seqlen_kv << " total: " << total_seqlen_kv << " vs even " << num_batches * get<1>(problem_size) << std::endl; + + ProblemShape problem_size_for_init = problem_size; + get<3,1>(problem_size_for_init) = 1; + get<0>(problem_size_for_init) = total_seqlen_q; + get<1>(problem_size_for_init) = total_seqlen_kv; + + ProblemShapeType problem_size_for_launch; + + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<2>(problem_size_for_launch) = get<2>(problem_size); + get<3>(problem_size_for_launch) = get<3>(problem_size); + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + + /// Initialize operands to be used in the GEMM and reference GEMM + + ProblemShapeType initialize(const Options& options) { + int h_r = options.h / options.h_k; + assert(options.h % options.h_k == 0); + auto problem_shape_in = cute::make_tuple(options.q, options.k, options.d, cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (kIsVarlen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } + else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + get<2>(problem_size) = cutlass::round_up(get<2>(problem_size), 8); // alignment + + auto shape_QO = select<0,2,3>(problem_size); + auto shape_KV = select<1,2,3>(problem_size); + auto shape_LSE = select<0,3>(problem_size); + + int SQ = size<0>(problem_size); + int SK = size<1>(problem_size); + int D = size<2>(problem_size); + int H = size<3,0>(problem_size); + int H_K = size<3,0,1>(problem_size); + int H_Q = size<3,0,0>(problem_size); + int B = size<3,1>(problem_size); + + stride_Q = make_stride(H*D , _1{}, make_stride(make_stride(D, H_Q*D), H*D*SQ)); + stride_O = stride_Q; + stride_K = make_stride(H_K*D , _1{}, make_stride(make_stride(_0{}, D), H_K*D*SK)); + stride_V = stride_K; + stride_LSE = make_stride(_1{}, make_stride(make_stride(SQ, SQ*H_Q), SQ*H)); + + if (kIsVarlen) { + get<2,1>(stride_Q) = 0; + get<2,1>(stride_K) = 0; + get<2,1>(stride_V) = 0; + get<2,1>(stride_O) = 0; + get<1,1>(stride_LSE) = 0; + } + + block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + block_LSE.reset(size(shape_LSE)); + block_ref_O.reset(size(shape_QO)); + block_ref_LSE.reset(size(shape_LSE)); + + initialize_block(block_Q, seed + 2023, options.init_style_q); + initialize_block(block_K, seed + 2022, options.init_style_k); + initialize_block(block_V, seed + 2021, options.init_style_v); + + if ( ! 
cumulative_seqlen_q.empty()) { + device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); + device_cumulative_seqlen_q.copy_from_host( + cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); + } + if ( ! cumulative_seqlen_kv.empty()) { + device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); + device_cumulative_seqlen_kv.copy_from_host( + cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); + } + + if constexpr (kIsVarlen) { + get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get(); + get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get(); + } + + return problem_shape; + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + + ProblemShapeType problem_shape = initialize(options); + + typename Operation::Arguments arguments{ + problem_shape, + { block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V }, + { block_O.get(), stride_O, + block_LSE.get(), stride_LSE }, + hw_info + }; + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + + cutlass::Status status = cutlass::Status::kSuccess; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. 
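+    // cudaEventSynchronize blocks the host until the event recorded after the final launch
+    // in the loop above has completed; the elapsed time measured next therefore spans all
+    // options.iterations launches and is averaged by dividing by options.iterations below.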
+ result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops; + if (kIsVarlen) { + flops = 0.0; + for (int i = 0; i < size<3,1>(problem_shape); i++) { + flops += (cumulative_seqlen_q[i+1] - cumulative_seqlen_q[i]) + * 1.0 + * (cumulative_seqlen_kv[i+1] - cumulative_seqlen_kv[i]); + } + } + else { + flops = 1.0; + flops *= static_cast(size<0>(problem_shape)); + flops *= static_cast(size<1>(problem_shape)); + flops *= static_cast(size<3,1>(problem_shape)); + } + flops *= 4.0 * (std::is_same_v ? 0.5 : 1.0); + flops *= static_cast(size<2>(problem_shape)); + flops *= static_cast(size<3,0>(problem_shape)); + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_tc_s = tflops_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_shape); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms << "ms, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, const char* name, auto... kernel_options) { + if ((! options.kernel_filter.empty()) && (! 
std::regex_search(name, std::basic_regex(options.kernel_filter)))) { + return; + } + if (options.varlen) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + } + else + { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + } + }; + + using HeadDim = _128; + + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, const char* name, auto... kernel_options) { + if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) { + return; + } + if (options.varlen) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + } + else + { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + } + }; + + using HeadDim = _64; + + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, const char* name, auto... kernel_options) { + if (options.varlen) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + } + else { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + } + }; + + using HeadDim = _32; + +#ifdef FP8 + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); +#endif +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) { + std::cout + << "This example requires a GPU of NVIDIA's Blackwell Architecture " + << "(compute capability major 10) and CUDA 12.8 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + if (options.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + else { + hw_info.sm_count = options.sm_count; + } + + std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " "; + std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + auto with_mask = [&](auto fn) { + if (options.causal) { + fn(CausalMask{}); + } + else if (options.residual) { + fn(ResidualMask{}); + } + else { + fn(NoMask{}); + } + }; + + with_mask([&](auto fusion) { + if (options.d <= 32) { + run_fwd_32(fusion, options, hw_info); + } + else if (options.d <= 64) { + run_fwd_64(fusion, options, hw_info); + } + else if (options.d <= 128) { + run_fwd_128(fusion, options, hw_info); + } + else { + std::cout << "No kernel instantiated for d=" << options.d << std::endl; + } + }); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu new file mode 100644 index 0000000000..9ac6f5894f --- /dev/null +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu @@ -0,0 +1,832 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example implementation of fused multi-head attention for the NVIDIA Blackwell SM100 + architecture using CUTLASS 3. + + MQA/GQA + ------- + + The head dimension can be represented as a tuple, where the K/V strides in the + first dimension is zero. This has the effect of MQA or GQA. + * MHA is (head_size:head_stride). + * MQA is (head_size:head_stride) in Q and (head_size:_0) in K and V. + * GQA is (grouped_heads,heads_kv):(head_stride,grouped_heads*head_stride) in Q + and (grouped_heads,heads_kv):(0,head_stride) in K and V + + Example usage: + $ ./examples/77_blackell_fmha/77_blackell_fmha_gen_fp8 \ + --b=2048 --h=2048 --d=2048 --k=2048 +*/ + +#define DSHOW(x) print(#x ": "); print(x); print("\n"); +#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n"); + +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "reference/fmha_fwd_gen_reference.hpp" +#include "reference/reference_abs_error.hpp" + +#include "device/fmha.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_gen_mainloop_warpspecialized.hpp" +#include "collective/sm100_fmha_gen_epilogue_warpspecialized.hpp" +#include "kernel/sm100_fmha_gen_kernel_warpspecialized.hpp" +#include "kernel/fmha_tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class InitStyle { + kZero, kOne, kLinearStride128, kLinearStride1, kRandom, kNone +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help = false; + bool error = false; + + int b = 1; + int h = 1; + int h_k = 1; + int k = 512; + int d = 128; + int iterations = 3; + bool verify = false; + bool verbose = false; + bool remap = false; + bool varlen = false; + bool cache_only = false; + + int sm_count = 
0; + + std::string kernel_filter; + bool clear_cache = false; + + InitStyle init_style_q = InitStyle::kRandom; + InitStyle init_style_cache_k = InitStyle::kRandom; + InitStyle init_style_cache_v = InitStyle::kRandom; + InitStyle init_style_new_k = InitStyle::kRandom; + InitStyle init_style_new_v = InitStyle::kRandom; + + static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) { + std::string s; + cmd.get_cmd_line_argument(name, s, s); + if (s.empty()) { + dst = src; + } + else { + if (s == "r") { + dst = InitStyle::kRandom; + } + else if (s == "0") { + dst = InitStyle::kZero; + } + else if (s == "1") { + dst = InitStyle::kOne; + } + else if (s == "d") { + dst = InitStyle::kLinearStride1; + } + else if (s == "s") { + dst = InitStyle::kLinearStride128; + } + else if (s == "n") { + dst = InitStyle::kNone; + } + else { + std::cout << "Error: " << s << " is not a valid input type.\n"; + std::exit(-1); + } + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("d", d, defaults.d); + cmd.get_cmd_line_argument("h", h, -1); + if (h == -1) h = 2048 / d; + + cmd.get_cmd_line_argument("h_k", h_k, -1); + if (h_k == -1) h_k = h; + + cmd.get_cmd_line_argument("k", k, defaults.k); + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + varlen = cmd.check_cmd_line_flag("varlen"); + remap = cmd.check_cmd_line_flag("remap"); + cache_only = cmd.check_cmd_line_flag("cache-only"); + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); + + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_cache_k, defaults.init_style_cache_k); + get_init_style_argument(cmd, "init-style", init_style_cache_v, defaults.init_style_cache_v); + get_init_style_argument(cmd, "init-style", init_style_new_k, defaults.init_style_new_k); + get_init_style_argument(cmd, "init-style", init_style_new_v, defaults.init_style_new_v); + get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q); + get_init_style_argument(cmd, "init-style-cache-k", init_style_cache_k, init_style_cache_k); + get_init_style_argument(cmd, "init-style-cache-v", init_style_cache_v, init_style_cache_v); + get_init_style_argument(cmd, "init-style-new-k", init_style_new_k, init_style_new_k); + get_init_style_argument(cmd, "init-style-new-v", init_style_new_v, init_style_new_v); + + clear_cache = cmd.check_cmd_line_flag("clear-cache"); + + cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "77_blackwell_fmha_gen\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " fused multi-head attention forward-pass gen-phase kernels targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --h= Sets the H extent\n" + << " --h_k= Sets the H_K/V extent (for GQA/MQA)\n" + << " --k= Sets the K extent (sampled around this length)\n" + << " --d= Sets the D extentn" + << " --iterations= Benchmarking iterations\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --remap Enables batch index remapping\n" + << " --cache-only Only use data from KV cache, no reading or inserting new entry\n" + << " --varlen Varies sequence length between cache entries\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --clear-cache Clears the cache before benchmarking runs\n" + << " --kernel-filter= Sets regexp to match kernel against\n" + << "\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_block( + DeviceAllocation& block, + uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) { + + switch (init_style) { + case InitStyle::kZero: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 0, (Element) 0); + break; + } + case InitStyle::kOne: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + break; + } + case InitStyle::kRandom: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) 0, (Element) 1); + break; + } + case InitStyle::kLinearStride1: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (j % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kLinearStride128: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (i % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kNone: { + break; + } + + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool supported = false; + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_tc_s = 0; + double tops_exp2_s = 0; + double tbytes_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ClearCache { + const int size = 1024 * 1024 * 1024 / 4; + DeviceAllocation data; + bool active = false; + + ClearCache() = default; + + void set_active(bool the_active) { + active = the_active; + if (active) { + data.reset(size); + } + else { + data.reset(0); + } + } + + void operator ()() { + if (active) { + initialize_block(data, 0x49314, InitStyle::kRandom); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class KernelType { + UMMA_P, UMMA_I +}; + 
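+// Both tags run the UMMA-based generation kernel; the suffix selects the tile scheduler
+// variant used by ExampleRunner below (UMMA_P: persistent scheduler, UMMA_I: individual,
+// non-persistent scheduler).
+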
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner { + + using Element = cutlass::float_e5m2_t; + using ElementAcc = float; + using ElementOut = cutlass::half_t; + + using ProblemShape = Shape<_1, int, int, Shape, int>>; + + using StrideQ = Stride<_0, _1, Stride, int>>; + using StrideNewK = Stride<_0, _1, Stride, int>>; + using StrideCacheK = Stride, int>>; + using StrideNewV = StrideNewK; + using StrideCacheV = StrideCacheK; + using StrideO = StrideQ; + + using Kernel = + cutlass::fmha::kernel::Sm100FmhaGenKernelWarpspecialized< + ProblemShape, + cutlass::fmha::collective::Sm100FmhaGenMainloopWarpspecialized< + Element, ElementAcc, ElementAcc, ElementOut, + TileShape, + StrideQ, StrideNewK, StrideNewV, + StrideCacheK, StrideCacheV, StrideO + >, + cutlass::fmha::collective::Sm100FmhaGenEpilogueWarpspecialized, + std::conditional_t + >; + + using Operation = cutlass::fmha::device::FMHA; + + StrideQ stride_q; + StrideNewK stride_new_k; + StrideNewV stride_new_v; + StrideCacheK stride_cache_k; + StrideCacheV stride_cache_v; + StrideO stride_o; + uint64_t seed = 0; + + std::vector seqlen_kv; + + DeviceAllocation block_seqlen_kv; + DeviceAllocation block_cache_batch_idx; + DeviceAllocation block_q; + DeviceAllocation block_new_k; + DeviceAllocation block_new_v; + DeviceAllocation block_cache_k; + DeviceAllocation block_cache_v; + DeviceAllocation block_o; + + DeviceAllocation block_ref_cache_k; + DeviceAllocation block_ref_cache_v; + DeviceAllocation block_ref_o; + + ClearCache clear_cache; + + bool verify(const ProblemShape& problem_shape) { + + Tensor mQ = make_tensor(make_gmem_ptr(block_q.get()), select<0,2,3>(problem_shape), stride_q); + Tensor mNewK = make_tensor(make_gmem_ptr(block_new_k.get()), select<0,2,3>(problem_shape), stride_new_k); + Tensor mNewV = make_tensor(make_gmem_ptr(block_new_v.get()), select<0,2,3>(problem_shape), stride_new_v); + Tensor mCacheK = make_tensor(make_gmem_ptr(block_ref_cache_k.get()), select<1,2,3>(problem_shape), stride_cache_k); + Tensor mCacheV = make_tensor(make_gmem_ptr(block_ref_cache_v.get()), select<1,2,3>(problem_shape), stride_cache_v); + Tensor mO = make_tensor(make_gmem_ptr(block_ref_o.get()), select<0,2,3>(problem_shape), stride_o); + + fmha_fwd_gen_reference( + problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(), + mQ, mNewK, mNewV, mCacheK, mCacheV, mO); + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; + reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff); + bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_O) { + std::cerr << "failed O: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff); + bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if ( ! 
passed_K) { + std::cerr << "failed Cache K: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff); + bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if ( ! passed_V) { + std::cerr << "failed Cache V: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + return passed_O && passed_K && passed_V; + } + + ProblemShape initialize(const Options& options) { + + clear_cache.set_active(options.clear_cache); + + std::vector cache_batch_idx; + + // set up stides and sizes + if (options.remap) { + for (int i = 0; i < options.b; i++) { + cache_batch_idx.push_back(i); + } + std::mt19937 rng(0x202305291305ull); + std::shuffle(cache_batch_idx.begin(), cache_batch_idx.end(), rng); + } + + seqlen_kv = std::vector(options.b, options.k); + if (options.varlen) { + std::mt19937 rng(0x202305151552ull); + std::normal_distribution dist_kv(options.k, options.k / 2); + + auto generate_positive_int = [](auto& dist, auto& gen) { + int result = 0; + do { + result = static_cast(dist(gen)); + } while (result <= 0); + return result; + }; + + for (int i = 0; i < options.b; i++) { + seqlen_kv[i] = generate_positive_int(dist_kv, rng); + } + } + + int max_seqlen_kv = 0; + for (auto e : seqlen_kv) { + // if (options.varlen) std::cout << "seqlen " << e << std::endl; + max_seqlen_kv = std::max(e, max_seqlen_kv); + } + + ProblemShape result = make_shape(_1{}, max_seqlen_kv + 1, options.d, make_shape(make_shape(options.h / options.h_k, options.h_k), options.b)); + + stride_q = make_stride(_0{}, _1{}, make_stride(make_stride(options.d, options.d * size<3,0,0>(result)), options.d * size<3,0>(result))); + stride_new_k = make_stride(_0{}, _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result))); + stride_cache_k = make_stride(options.d * size<3,0,1>(result), _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result) * get<1>(result))); + + stride_new_v = stride_new_k; + stride_cache_v = stride_cache_k; + stride_o = stride_q; + + block_q.reset(options.b * get<2,1>(stride_q)); + if (! options.cache_only) { + block_new_k.reset(options.b * get<2,1>(stride_new_k)); + block_new_v.reset(options.b * get<2,1>(stride_new_v)); + } + block_cache_k.reset(options.b * get<2,1>(stride_cache_k)); + block_cache_v.reset(options.b * get<2,1>(stride_cache_v)); + block_o.reset(options.b * get<2,1>(stride_o)); + + block_ref_cache_k.reset(options.b * get<2,1>(stride_cache_k)); + block_ref_cache_v.reset(options.b * get<2,1>(stride_cache_v)); + block_ref_o.reset(options.b * get<2,1>(stride_o)); + + initialize_block(block_q, seed + 2023, options.init_style_q); + if (! options.cache_only) { + initialize_block(block_new_k, seed + 2022, options.init_style_new_k); + initialize_block(block_new_v, seed + 2021, options.init_style_new_v); + } + + initialize_block(block_cache_k, seed + 2024 - 2025, options.init_style_cache_k); + initialize_block(block_cache_v, seed + 2025, options.init_style_cache_v); + + block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size()); + block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size()); + block_seqlen_kv.reset(seqlen_kv.size()); + block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size()); + + if (! 
cache_batch_idx.empty()) { + block_cache_batch_idx.reset(cache_batch_idx.size()); + block_cache_batch_idx.copy_from_host(cache_batch_idx.data(), cache_batch_idx.size()); + } + + return result; + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + auto problem_shape = initialize(options); + + typename Operation::Arguments arguments{ + problem_shape, + block_seqlen_kv.get(), block_cache_batch_idx.get(), + block_q.get(), stride_q, + block_new_k.get(), stride_new_k, + block_new_v.get(), stride_new_v, + block_cache_k.get(), stride_cache_k, + block_cache_v.get(), stride_cache_v, + block_o.get(), stride_o, + hw_info + }; + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + + cutlass::Status status = cutlass::Status::kSuccess; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + // std::cerr << "This kernel is not supported. Last CUDA error is: " + // << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + example_result.supported = true; + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + float total_runtime_ms = 0; + + for (int i = 0; i < options.iterations; i++) { + + clear_cache(); + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Stop profiling loop + // + + // Wait for work on the device to complete. 
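+      // Note that the events are recorded inside the iteration loop, so each measured
+      // interval covers a single launch and excludes the optional clear_cache() pass that
+      // precedes it; the per-iteration times are accumulated and averaged below.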
+ result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + total_runtime_ms += runtime_ms; + + } + + float runtime_ms = total_runtime_ms / static_cast(options.iterations); + + double bytes; + bytes = 0.0; + bytes += double(sizeof(Element) * size<3>(problem_shape)); // Q + bytes += double(sizeof(ElementOut) * size<3>(problem_shape)); // O + bytes += 2.0 * double(sizeof(Element) * size<3>(problem_shape) / size<3,0,0>(problem_shape)); // NewK, NewV + double total_seqlen_kv = 0; + for (auto e : seqlen_kv) { + total_seqlen_kv += double(e + 1); + } + bytes += 2.0 * double(sizeof(Element) * size<3,0,1>(problem_shape) * total_seqlen_kv); // CacheK, CacheV + bytes *= static_cast(size<2>(problem_shape)); + double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tbytes_s = tbytes_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_shape); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.supported ? (result.passed ? (result.verified ? 
" [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tbytes_s << " TB/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms << "ms, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major < 10) { + std::cout + << "This example requires a GPU of NVIDIA's Blackwell Architecture or " + << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + if (options.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + else { + hw_info.sm_count = options.sm_count; + } + + std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " K " << options.k << " D " << options.d << " "; + std::cout << "Gen" << " " << (options.varlen ? "Variable" : "Uniform") << " " << (options.remap ? "Remap" : "Linear") << " "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + using UMMA = true_type; + using FFMA2 = false_type; + auto run = [&](const char* name, auto kernel_type, auto tile, auto thr) { + if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) { + return; + } + ExampleRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + + #define RUN(MODE, m, n, k, tm, tn, tk) \ + run( \ + #MODE " " #m "x" #n "x" #k " / " #tm "x" #tn "x" #tk, \ + std::integral_constant{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \ + ) + + RUN(UMMA_I, 128, 64, 128, 1, 1, 1); + RUN(UMMA_I, 128, 128, 128, 1, 1, 1); + RUN(UMMA_I, 128, 256, 128, 1, 1, 1); + RUN(UMMA_P, 128, 64, 128, 1, 1, 1); + RUN(UMMA_P, 128, 128, 128, 1, 1, 1); + RUN(UMMA_P, 128, 256, 128, 1, 1, 1); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? 
"" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt new file mode 100644 index 0000000000..4c9e784a58 --- /dev/null +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -0,0 +1,105 @@ +# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+
+set_property(
+  SOURCE 77_blackwell_fmha.cu
+  PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
+
+set_property(
+  SOURCE 77_blackwell_fmha_gen.cu
+  PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
+
+set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
+set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
+set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
+set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
+set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
+
+set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
+set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
+set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
+set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
+set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
+set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
+
+if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
+  if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+    cutlass_example_add_executable(
+      77_blackwell_fmha_fp8
+      77_blackwell_fmha.cu
+      TEST_COMMAND_OPTIONS
+      TEST_BASIC
+      # TEST_CAUSAL
+      # TEST_VARLEN
+      # TEST_HDIM64
+      # TEST_GQA)
+      )
+    target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+    target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8)
+
+    cutlass_example_add_executable(
+      77_blackwell_fmha_gen_fp8
+      77_blackwell_fmha_gen.cu
+      TEST_COMMAND_OPTIONS
+      TEST_GEN_BASIC
+      # TEST_GEN_VARLEN
+      # TEST_GEN_HDIM64
+      # TEST_GEN_GQA
+      # TEST_GEN_REMAP
+      # TEST_GEN_CACHEONLY)
+      )
+    target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+    target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8)
+
+    cutlass_example_add_executable(
+      77_blackwell_fmha_fp16
+      77_blackwell_fmha.cu
+      TEST_COMMAND_OPTIONS
+      TEST_BASIC
+      # TEST_CAUSAL
+      # TEST_VARLEN
+      # TEST_HDIM64
+      # TEST_GQA)
+      )
+    target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+
+    cutlass_example_add_executable(
+      77_blackwell_fmha_gen_fp16
+      77_blackwell_fmha_gen.cu
+      TEST_COMMAND_OPTIONS
+      TEST_GEN_BASIC
+      # TEST_GEN_VARLEN
+      # TEST_GEN_HDIM64
+      # TEST_GEN_GQA
+      # TEST_GEN_REMAP
+      # TEST_GEN_CACHEONLY)
+      )
+    target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+  endif()
+endif()
diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md
new file mode 100644
index 0000000000..8766f081ff
--- /dev/null
+++ b/examples/77_blackwell_fmha/README.md
@@ -0,0 +1,23 @@
+# FMHA for Blackwell: Forward
+
+This sample provides code for the fused multi-head attention forward pass, in either the context or the generation phase.
+It supports HeadDims of 32, 64, and 128, and fp8, fp16, and bf16 input data types.
+
+For forward or context usage, use an M-blocking (Seqlen-Q) of 256 and an N-blocking (Seqlen-K) of 128.
+For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit is currently 32 for actual Num-Groups), and an N-blocking (Seqlen-K) of 64, 128 or 256.
+
+Context loads are done via TMA, whereas generation usage utilizes `cp.async` and is thus more amenable to complex load patterns.
+
+For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first input batch.
This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively. + +The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel. +The kernel and collective layer are then formulated to be fmha-specific. +The design assigns two tiles to each threadblock, and pingpongs between them in terms of matrix-matrix multiplication and softmax. + +The example builds four binaries, showcasing the context and generation usage for fp8 and fp16. +For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them. + +To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easiest customization point. +The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements. +It is well-suited for applying masks or activations. +More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA. diff --git a/examples/77_blackwell_fmha/collective/fmha_common.hpp b/examples/77_blackwell_fmha/collective/fmha_common.hpp new file mode 100644 index 0000000000..c60d9e953e --- /dev/null +++ b/examples/77_blackwell_fmha/collective/fmha_common.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +template +CUTLASS_DEVICE +void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } + else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp new file mode 100644 index 0000000000..85138b0bd0 --- /dev/null +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + return; + } +}; + +struct ResidualMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. 
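+    // Setting the out-of-range logits to -INFINITY means exp2f() in the
+    // softmax maps them to 0, so masked columns contribute neither to the
+    // row sum nor to the subsequent P*V GEMM.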
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct CausalMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<0>(tile_shape), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +struct VariableLength { + int max_length; + int* cumulative_length = nullptr; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template struct is_variable_length : std::false_type {}; +template<> struct is_variable_length : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length::value; + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v>) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v>) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } + else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template<> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp 
b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000..8240080170 --- /dev/null +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element, + class ElementAcc, + class TileShape, // Q, D, _ + class StrideO, // Q, D, B + class StrideLSE // Q, B +> +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + +// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>()); +// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_O; + StrideO dO; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_,_,_0{}) + )); + + + struct Params { + TMA_O tma_store_o; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + + auto ptr_O = args.ptr_O; + StrideO dO = args.dO; + auto problem_shape_O = select<0,2,3>(problem_shape); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dO) = get<0>(dO); + get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + // offset ptr by the amount we add back in later + ptr_O -= max_length_q * get<0>(dO); + } + } + + auto tma_store_o = make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(ptr_O, problem_shape_O, dO), + SmemLayoutO{}(_,_,_0{}) + ); + + return { + tma_store_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + + BlkCoord blk_coord = blk_coord_in; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape)); + // offset mode 0 by (max_length - real_length) + // offset mode 3,1 by cumulative_length + real_length + // the ptr is already offset by - max_length + // so in total this achieves + int offs_0 = 0; + int offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if 
(cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + offs_0 = max_length_q - get<0>(problem_shape); + offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape); + get<2,1>(blk_coord) = 0; + } + } + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord)); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index)); + } + tma_store_arrive(); + + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000000..3f063f564a --- /dev/null +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1102 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class TileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1> +> +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = TileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = sizeof(Element_) == 1 ? 
4 : 3; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaUmmaAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadKV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + + static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v), "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized< + Element, StrideQ, 
StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS 
= partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + 
++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... 
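+    // In other words, the MMA warp ping-pongs between the two softmax stages:
+    // each K/V tile feeds the GEMMs of both halves, and the P*V GEMM of one
+    // half is issued while the softmax of the other half is still in flight,
+    // hiding the softmax latency behind the tensor core work.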
+ } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CountingTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); +// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK 
row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum 
= local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, 
make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); +// tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); +// tTMEM_LOADsO_i.data() = tTMEM_LOADsO_i.data().get() + sO.layout()(_0{}, i * kCorrectionTileSize, _0{}); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = 
thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + //tTMrO(j) = scale * tTMrO(j); + //tTMrO(j+1) = scale * tTMrO(j+1); + } + + copy_out(i); + } + } + + template + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the 
last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace 
cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp new file mode 100644 index 0000000000..3d1e1e8be3 --- /dev/null +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element_, + class StrideO_ +> +struct Sm100FmhaGenEpilogueWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + + using SmemLayoutO = Layout>; + using SmemLayoutO_ = SmemLayoutO; + using Element = Element_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideOOrig{}, 0)); + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_o; + StrideO dO; + }; + + using Params = Arguments; + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaGenEpilogueWarpspecialized(const Params& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + /* no-op */ + } + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + /* no-op */ + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp new file mode 100644 index 0000000000..38b2619661 --- /dev/null +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -0,0 +1,1116 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ElementOut_, + class TileShape_, + class StrideQ_, + class StrideNewK_, + class StrideNewV_, + class StrideK_, + class StrideV_, + class StrideO_, + class Mask_ = ResidualMask, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_1, _2, _1> +> +struct Sm100FmhaGenMainloopWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ElementAcc = ElementPV_; + using ElementOut = ElementOut_; + using TileShape = TileShape_; + using StrideQOrig = StrideQ_; + using StrideQ = decltype(replace<0>(StrideQ_{}, 0)); + using StrideNewK = StrideNewK_; + using StrideNewV = StrideNewV_; + using StrideCacheK = StrideK_; + using StrideCacheV = StrideV_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideO_{}, 0)); + using Mask = Mask_; + + static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 
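The ThreadShape comment above decides how the two softmax warpgroups split the work, and the tile types derived right after it come from shape_div and select. A small CuTe sketch of that derivation; the 128x128x128 tile is an assumed example, not a value taken from this patch:

// Illustration only: per-warpgroup tile shapes implied by ThreadShape (1, 2, 1).
#include <cute/layout.hpp>
using namespace cute;

using TileShapeEx    = Shape<_128, _128, _128>;                             // (Q, KV, D), hypothetical
using ThreadShapeEx  = Shape<_1, _2, _1>;                                   // softmax warps side by side
using TileShapeQKEx  = decltype(shape_div(TileShapeEx{}, ThreadShapeEx{})); // (_128, _64, _128)
using TileShapePVEx  = decltype(select<0,2,1>(TileShapeQKEx{}));            // (_128, _128, _64)

static_assert(get<1>(TileShapeQKEx{}) == 64, "each softmax warpgroup sees half of the KV tile");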
1 : 2; + static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineUmmaConsumerAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineUmmaConsumerAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v), "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadCpAsyncWarpspecialized< + Element, StrideQ, StrideNewK, StrideNewV, StrideCacheK, StrideCacheV, + TensorStorage, CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + PipelineQ, PipelineKV, 
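The TmemAllocation enum above hand-packs the accumulators into tensor-memory columns, with the softmax statistics (V0/V1) and the converted P tiles aliasing columns inside the S blocks. Mirroring the arithmetic makes the packing easy to audit; the 512-column total is my reading of the tmem capacity, so treat the final assert as an assumption:

// Mirrors the TmemAllocation offsets above (all values in 32-bit tmem columns).
constexpr unsigned kSizeS = 128, kSizeO = 128, kSizeP = 32;
constexpr unsigned S0 = 0,           S1 = S0 + kSizeS;  // 0, 128
constexpr unsigned V0 = S0,          V1 = S1;           // stats live at the front of each S block
constexpr unsigned P0 = S0 + kSizeP, P1 = S1 + kSizeP;  // 32, 160: P re-uses columns inside S
constexpr unsigned O0 = S1 + kSizeS, O1 = O0 + kSizeO;  // 256, 384
constexpr unsigned kEnd = O1 + kSizeO;                  // 512
static_assert(kEnd == 512, "S0 | S1 | O0 | O1 exactly fill the assumed 512 tmem columns");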
TileShape, Mask + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, 
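Earlier in this collective, to_underlying_arguments folds every static factor into just two kernel-side constants: the Q/K dequantization scales and 1/sqrt(D) go into scale_softmax (and, multiplied by log2(e), into scale_softmax_log2 so the kernel can call exp2f instead of expf), while the V dequantization and O quantization scales become scale_output. A short host-side sketch of the identity this relies on; check_scale_folding and its arguments are illustrative only:

// exp(s * x) == exp2(s * log2(e) * x), so folding log2(e) once lets the kernel use exp2f.
#include <cassert>
#include <cmath>

inline void check_scale_folding(float scale_q, float scale_k, float scale_v,
                                float inv_scale_o, int head_dim, float x) {
  float scale_softmax    = 1.0f / std::sqrt(float(head_dim));              // default when args.scale_softmax == 0
  float scale_fused      = scale_q * scale_k * scale_softmax;              // Params::scale_softmax
  float scale_fused_log2 = scale_fused * float(std::log2(std::exp(1.0)));  // Params::scale_softmax_log2
  float scale_output     = scale_v * inv_scale_o;                          // Params::scale_output
  (void)scale_output;
  float a = std::exp(scale_fused * x);
  float b = std::exp2(scale_fused_log2 * x);
  assert(std::abs(a - b) <= 1e-3f * std::abs(a));
}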
select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), 
tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... 
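The interleaving trace in the comment above is easier to follow as pseudocode: after the prologue, every KV tile issues one QK GEMM and one PV GEMM for each of the two softmax stages, so the tensor core always has work while one softmax warpgroup is exponentiating. A stripped-down sketch with the pipeline synchronization elided; mma_schedule_sketch and the qk/pv callables are stand-ins, not part of this patch:

// qk(i, s): S[s] = Q[s] * K[i];   pv(i, s): O[s] (+)= P[s] * V[i]  (first touch zero-initializes)
template <class QkFn, class PvFn>
void mma_schedule_sketch(int kv_tiles, QkFn&& qk, PvFn&& pv) {
  qk(0, 0);               // Q1 * K1 -> S0
  qk(0, 1);               // Q2 * K1 -> S1
  pv(0, 0);               // P0 * V1 -> O0
  for (int i = 1; i < kv_tiles; ++i) {
    qk(i, 0);             // Q1 * Ki     -> S0
    pv(i - 1, 1);         // P1 * V(i-1) -> O1
    qk(i, 1);             // Q2 * Ki     -> S1
    pv(i, 0);             // P0 * Vi     -> O0
  }
  pv(kv_tiles - 1, 1);    // final P1 * Vn -> O1
}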
+ } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CountingTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); +// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK 
row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum 
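The unrolled loop above does the scale-and-shift with packed float2 FMAs, exponentiates with exp2f, and converts straight to the P element type, while the row sum is kept in four independent partial accumulators so the adds do not serialize. A scalar sketch of the same per-element math and of the split reduction; Element here stands for whatever fp8/fp16 type P uses, and both helper names are illustrative:

#include <cmath>

// One element of S -> one element of P, matching the fma + exp2f + convert sequence above.
template <class Element>
inline Element softmax_element(float s, float row_max_safe, float scale_softmax_log2) {
  float p = std::exp2(scale_softmax_log2 * s - scale_softmax_log2 * row_max_safe);
  return static_cast<Element>(p);   // the kernel batches this through NumericArrayConverter
}

// Row sum with four independent partials, the scalar analogue of local_row_sum_{0..3}.
// Assumes n is a multiple of 4.
inline float row_sum_sketch(const float* p, int n) {
  float s0 = 0, s1 = 0, s2 = 0, s3 = 0;
  for (int i = 0; i < n; i += 4) {
    s0 += p[i]; s1 += p[i + 1]; s2 += p[i + 2]; s3 += p[i + 3];
  }
  return (s0 + s1) + (s2 + s3);
}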
= local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1, + GTensor& gO, CTensor const& cO, Shape const& g_shape, + Epilogue const& epilogue) { + + using ElementOut = typename GTensor::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOgO = mma.get_slice(0).partition_C(gO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = 
tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO0 = tOtO_i; + tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO_i; + tOtO1.data() = tOtO1.data().get() + uint32_t(TmemAllocation::O1); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0); + Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_LOADgO = thr_tmem_load.partition_D(tOgO_i); + + float row_max = std::max(v0(kIdxFinalRowMax), v1(kIdxFinalRowMax)); + float adj0 = ::exp2f(scale_softmax_log2 * (v0(kIdxFinalRowMax) - row_max)); + float adj1 = ::exp2f(scale_softmax_log2 * (v1(kIdxFinalRowMax) - row_max)); + float row_sum = adj0 * v0(kIdxFinalRowSum) + adj1 * v1(kIdxFinalRowSum); + float scale0 = scale_out * adj0 / row_sum; + float scale1 = scale_out * adj1 / row_sum; + + float2 scale0_f32x2 = make_float2(scale0, scale0); + float2 scale1_f32x2 = make_float2(scale1, scale1); + + // loop: + // TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 128 / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0; + tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1; + tTMEM_LOADtO1_i.data() = tTMEM_LOADtO1_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADgO_i = tTMEM_LOADgO; + tTMEM_LOADgO_i.data() = tTMEM_LOADgO_i.data().get() + i * kCorrectionTileSize * stride<1>(gO); + + Tensor tTMrO0 = make_tensor(shape(tTMEM_LOADcO)); + Tensor tTMrO1 = make_tensor(shape(tTMEM_LOADcO)); + + copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0); + copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO0); j += 2) { + float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1)); + float2 in1 = make_float2(tTMrO1(j), tTMrO1(j+1)); + float2 out; + cute::mul(out, scale0_f32x2, in0); + cute::fma(out, scale1_f32x2, in1, out); + tTMrO0(j) = out.x; + tTMrO0(j+1) = out.y; + } + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO0); + + Tensor tCs = recast(tTMrO0); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMgO_i = recast(tTMEM_LOADgO_i); + Tensor tSMrO_i = recast(tSMrO); + + // could use masking do this right for smaller D + if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i); + } + } + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO 
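Because the two softmax warpgroups of this decode mainloop cover disjoint halves of KV, each finishes with its own (row max, row sum, partial O), and correction_epilogue above merges the two with the usual log-sum-exp combine before applying the output scale. A scalar sketch of that merge rule for one output element; Partial and merge_and_normalize are illustrative names:

#include <algorithm>
#include <cmath>

struct Partial { float row_max; float row_sum; float o; };  // o: one element of the partial output

inline float merge_and_normalize(Partial a, Partial b, float scale_softmax_log2, float scale_out) {
  float m    = std::max(a.row_max, b.row_max);
  float adj0 = std::exp2(scale_softmax_log2 * (a.row_max - m));  // rescale each half to the common max
  float adj1 = std::exp2(scale_softmax_log2 * (b.row_max - m));
  float sum  = adj0 * a.row_sum + adj1 * b.row_sum;
  return (scale_out * adj0 / sum) * a.o + (scale_out * adj1 / sum) * b.o;
}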
= mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + //tTMrO(j) = scale * tTMrO(j); + //tTMrO(j+1) = scale * tTMrO(j+1); + } + + copy_out(i); + } + } + + template + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + Epilogue const& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = 
thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS0 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS0); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + Tensor tTMEM_LOADVrS1 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS1); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + auto pipeline_o_release_state = pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + auto g_shape 
= select<0,2>(problem_shape); + auto mO = make_tensor(make_gmem_ptr(epilogue.params.ptr_o), append<3>(select<0,1>(TileShapePV{}), get<3>(problem_shape)), epilogue.params.dO); + auto gO = mO(_, _, get<2>(blk_coord)); + + correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, tTMEM_LOADVrS1, gO, cO, g_shape, epilogue); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000..c201d4f040 --- /dev/null +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp @@ -0,0 +1,395 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideNewK, + class StrideNewV, + class StrideCacheK, + class StrideCacheV, + class TensorStorage, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class PipelineQ, + class PipelineKV, + class TileShape, + class Mask +> +struct Sm100FmhaLoadCpAsyncWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + + const int* cache_batch_idx; + + const Element* ptr_q; + StrideQ dQ; + + const Element* ptr_new_k; + StrideNewK dNewK; + const Element* ptr_new_v; + StrideNewV dNewV; + + Element* ptr_cache_k; + StrideCacheK dCacheK; + Element* ptr_cache_v; + StrideCacheV dCacheV; + }; + + using Params = Arguments; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + } + + template + CUTLASS_DEVICE auto constexpr transpose(Tensor const& t) { + CUTE_STATIC_ASSERT_V(rank(t) == _2{}); + return t.compose(make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{}))); + } + + template< + class CAtom, class TA, class TB, + class CountTensor, class CountLimit, + class SrcTensor, class DstTensor + > + CUTLASS_DEVICE void copy_with_limit( + TiledCopy const& tiled_copy, + CountTensor const& c, CountLimit const& l, + SrcTensor const& src, DstTensor&& dst) { + + //copy(tiled_copy, src, dst); +#if 1 + auto c_f = make_tensor(c.data(), flatten(c.layout())); + auto src_f = make_tensor(src.data(), flatten(src.layout())); + auto dst_f = make_tensor(dst.data(), flatten(dst.layout())); + auto c_v = group_modes<1,rank_v>(c_f); + auto src_v = group_modes<1,rank_v>(src_f); + auto dst_v = group_modes<1,rank_v>(dst_f); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(src_v); i++) { + if (elem_less(c_v(_0{}, i), l)) { + copy(CAtom{}, src_v(_, i), dst_v(_, i)); + } + else { + clear(dst_v(_, i)); + } + } +#endif + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + mask_tile_count *= 2; + + int warp_idx = (threadIdx.x / 32) % 2; + int thread_idx = warp_idx * 32 + (threadIdx.x % 32); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + auto blk_coord_cache = blk_coord; + if (params.cache_batch_idx != nullptr) { + get<2,1>(blk_coord_cache) = params.cache_batch_idx[get<2,1>(blk_coord_cache)]; + } + + // Q1, K1, K2, V1, K3, V2, ... 
Kn, Vn-1, Vn + // two pipes: Q and KV + auto cQ = make_identity_tensor(select<0,2>(TileShape{})); + auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), append<3>(select<0,2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ); + auto gQ = mQ(_, _, get<2>(blk_coord)); + auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + auto tSgQ = thr_mma_qk.partition_A(gQ); + auto tScQ = thr_mma_qk.partition_A(cQ); + + auto atom_q_tv = Layout, Shape<_16, _16>>, Stride, Stride<_1, _1024>>>{}; + auto atom_kv_tv = Layout, Shape<_16, _4>>, Stride, Stride<_1, _1024>>>{}; + + auto tiled_copy_q = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_q_tv, + make_layout(shape(tSgQ), replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{}))))); + + auto thr_copy_q = tiled_copy_q.get_slice(thread_idx); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQ = thr_copy_q.partition_S(tSgQ); + auto tQcQ = thr_copy_q.partition_S(tScQ); + + auto limitQ = append<2>(get<0>(problem_shape), _128{}); + + // Q1 + int q0_index = get<0>(blk_coord); +// pipeline_q.producer_acquire(pipeline_q_producer_state); + + // copy_with_limit(tiled_copy_q, tQcQ, limitQ, tQgQ, tQsQ(_, _, _, _, pipeline_q_producer_state.index()); + auto load_q = [&](int q_index, auto& state) { + pipeline_q.producer_acquire(state); + +// using Vec = Element; +// auto vzero = Element(0); + // q is always loaded masked + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + //auto src = recast(tQgQ(_, _, _, _, q_index)); + auto src = recast(tQgQ(_, _, _, _)); + auto dst = recast(tQsQ(_, _, _, _, state.index())); + // auto c = tQcQ(_, _, _, _, q_index); + auto c = tQcQ(_, _, _, _); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitQ); + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + }; + + load_q(q0_index, pipeline_q_producer_state); +// pipeline_q.producer_commit(pipeline_q_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_q_producer_state; + + auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{})); + auto cK = make_tensor(cK_t.data(), make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t)))); + auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1,2,3>(problem_shape), params.dCacheK); + auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), Step{}); + auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + auto tSgK = thr_mma_qk.partition_B(gK); + auto tScK = thr_mma_qk.partition_B(cK); + + auto tSlK = thr_mma_qk.partition_B(make_tensor((Element*) nullptr, make_ordered_layout(select<1,2>(TileShapeQK{}), Step<_1, _0>{}))); + auto tiled_copy_k = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_kv_tv, + tSlK.layout()); + + auto thr_copy_k = tiled_copy_k.get_slice(thread_idx); + + auto tKsK = thr_copy_k.partition_D(sK); + auto tKgK = thr_copy_k.partition_S(tSgK); + auto tKcK = thr_copy_k.partition_S(tScK); + + int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 
1 : 0); + auto limitK = append<2>(seqlen_cache_kv, _128{}); + + auto cV_t = make_identity_tensor(select<1,2>(TileShapePV{})); + auto cV = make_tensor(cV_t.data(), make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t)))); + auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2,1,3>(problem_shape), select<1,0,2>(params.dCacheV)); + auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), Step{}); + auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + typename CollectiveMmaPV::TiledMma mma_pv; + ThrMMA thr_mma_pv = mma_pv.get_slice(0); + auto tOgV = thr_mma_pv.partition_B(gV); + auto tOcV = thr_mma_pv.partition_B(cV); + auto tOlV = thr_mma_pv.partition_B(make_tensor((Element*) nullptr, make_layout(select<1,2>(TileShapePV{})))); + + auto tiled_copy_v = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_kv_tv, + tOlV.layout()); + + auto thr_copy_v = tiled_copy_v.get_slice(thread_idx); + + auto tVsV = thr_copy_v.partition_D(sV); + auto tVgV = thr_copy_v.partition_S(tOgV); + auto tVcV = thr_copy_v.partition_S(tOcV); + + auto limitV = select<1,0>(limitK); + + int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{}); + + bool has_new = params.ptr_new_k != nullptr; + Tensor mNewK = make_tensor(make_gmem_ptr(params.ptr_new_k), select<1,2,3>(problem_shape), params.dNewK); + Tensor mNewV = make_tensor(make_gmem_ptr(params.ptr_new_v), select<1,2,3>(problem_shape), params.dNewV); + Tensor gNewK = mNewK(_, _, get<2>(blk_coord)); + Tensor gNewV = mNewV(_, _, get<2>(blk_coord)); + + auto load_k = [&](int k_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (k_index < full_tiles_cache) { + copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { +// using Vec = Element; +// auto vzero = Element(0); + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tKgK(_, _, _, _, k_index)); + auto dst = recast(tKsK(_, _, _, _, state.index())); + auto src2 = recast(gNewK); + auto c = tKcK(_, _, _, _, k_index); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitK); + if (get<0>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<1>(cc) / vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto load_v = [&](int v_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (v_index < full_tiles_cache) { + copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { +// using Vec = Element; +// auto vzero = Element(0); + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tVgV(_, _, _, _, v_index)); + auto dst = recast(tVsV(_, _, _, _, state.index())); + auto src2 = recast(gNewV); + int vlen = sizeof(Vec) / sizeof(Element); + auto c = tVcV(_, _, _, _, v_index); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitV); + 
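Both load_k and load_v above rely on cp_async_zfill for their tails: a 16-byte fragment whose predicate fails is not skipped but written to shared memory as zeros, so the MMA can always consume full tiles without a separate clear, and the single in-flight new-token row simply redirects the source pointer with the guard forced true. A minimal, isolated sketch of the guarded pattern; the function, its parameters, and the fence/wait completion (the kernel instead arrives on a pipeline barrier) are illustrative:

// Predicated 16B copies: in-bounds fragments come from gmem, the rest are zero-filled.
// Assumes 16B-aligned pointers and elems_per_row divisible by the vector length.
#include "cutlass/arch/memory_sm80.h"

template <class Element>
__device__ void load_rows_guarded(Element* smem, const Element* gmem,
                                  int rows, int row_limit, int elems_per_row) {
  constexpr int kBytes = 16;
  const int vlen = kBytes / sizeof(Element);
  for (int r = threadIdx.x; r < rows; r += blockDim.x) {
    for (int c = 0; c < elems_per_row; c += vlen) {
      bool guard = (r < row_limit);   // false -> cp.async still issues, but zero-fills smem
      cutlass::arch::cp_async_zfill<kBytes, cutlass::arch::CacheOperation::Global>(
          smem + r * elems_per_row + c, gmem + r * elems_per_row + c, guard);
    }
  }
  cutlass::arch::cp_async_fence();
  cutlass::arch::cp_async_wait<0>();
}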
if (get<1>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<0>(cc) / vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + // K1 + int k_index = 0; + int v_index = 0; + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + mask_tile_count -= 1; + + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + } + + // V1 + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + + if (has_new) { + for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) { + gK(seqlen_cache_kv, i, 0) = gNewK(0, i); + gV(i, seqlen_cache_kv, 0) = gNewV(0, i); + } + } + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000..1951056b2c --- /dev/null +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,316 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape +> +struct Sm100FmhaLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = problem_shape; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + 
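    // A single elected thread prefetches each TMA descriptor up front so the tensor
    // maps are likely already cached when the load warp issues its first TMA copies;
    // the V descriptor is prefetched below as well.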
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = 
make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/77_blackwell_fmha/device/fmha.hpp b/examples/77_blackwell_fmha/device/fmha.hpp new file mode 100644 index 0000000000..f8406d3eb1 --- /dev/null +++ b/examples/77_blackwell_fmha/device/fmha.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
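  // The static and member functions below follow the usual CUTLASS 3.x device-adapter
  // flow: check can_implement(), size and allocate the workspace, initialize() to build
  // the kernel Params, then run(). A minimal host-side sketch follows; `FmhaKernel`,
  // `make_fmha_arguments()`, and `stream` are placeholders rather than names defined in
  // this file, and device_memory::allocation comes from the CUTLASS utility headers:
  //
  //   using Operation = cutlass::fmha::device::FMHA<FmhaKernel>;
  //   Operation fmha;
  //   typename Operation::Arguments args = make_fmha_arguments(/* problem, pointers */);
  //   if (Operation::can_implement(args) != cutlass::Status::kSuccess) {
  //     return;  // unsupported problem for this kernel
  //   }
  //   size_t workspace_bytes = Operation::get_workspace_size(args);
  //   cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);
  //   if (fmha.initialize(args, workspace.get(), stream) == cutlass::Status::kSuccess) {
  //     fmha.run(stream);   // or fmha(stream) via operator()
  //   }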
+ static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. 
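  // Note that update() below rebuilds the complete Params from Arguments via
  // Kernel::to_underlying_arguments(); despite its name it is not an in-place patch of
  // the existing parameters, it only skips re-running workspace initialization.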
+ Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/kernel/fmha_options.hpp b/examples/77_blackwell_fmha/kernel/fmha_options.hpp new file mode 100644 index 0000000000..d4faa8d215 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/fmha_options.hpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp b/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000000..35964cb6a3 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_h; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), 
size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidb, bidh)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000..fbb8d362b1 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,519 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule +> +struct Sm100FmhaFwdKernelTmaWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + 
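      // Pipeline dataflow in this kernel: load_q/load_kv feed the MMA warp; mma_s0 and
      // mma_s1 hand S tiles to the two softmax warpgroups; s0_corr, s1_corr, and
      // mma_corr feed the correction warpgroup; corr_epi hands the corrected output to
      // the epilogue warp; order_s01 keeps the two softmax warpgroups issuing in order.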
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } 
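    // As with the Q pipeline above, each warp declares how it participates: the single
    // load warp is the TMA producer of K/V tiles and the MMA warp is the consumer;
    // is_leader and transaction_bytes below configure the TMA transaction barrier.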
+ if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadKV; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * 
cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); 
++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.epilogue, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state + ); + + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if 
(get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + shared_storage.epilogue, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp new file mode 100644 index 0000000000..92c7d3717d --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -0,0 +1,576 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaGenKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (warp_idx == 1) return WarpRole::MMA; // 12 + if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13 + if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7 + if (warp_idx == 8) return WarpRole::Correction; // 8 - 11 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 1; + static const int NumWarpsCorrection = 1; + static const int NumWarpsEpilogue = 0; + static const int NumWarpsLoad = 2; + + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 104; + static const int NumRegsOther = 248; + static const int NumRegsEmpty = 24; + + static const int NumWarps = 12; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule +> +struct Sm100FmhaGenKernelWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0)); + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename 
CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using StrideQOrig = typename CollectiveMainloop::StrideQOrig; + using StrideOOrig = typename CollectiveMainloop::StrideOOrig; + using StrideQ = typename CollectiveMainloop::StrideQ; + using StrideO = typename CollectiveMainloop::StrideO; + using StrideCacheK = typename CollectiveMainloop::StrideCacheK; + using StrideCacheV = typename CollectiveMainloop::StrideCacheV; + using StrideNewK = typename CollectiveMainloop::StrideNewK; + using StrideNewV = typename CollectiveMainloop::StrideNewV; + using Element = typename CollectiveMainloop::Element; + using ElementAcc = typename CollectiveMainloop::ElementAcc; + using ElementOut = typename CollectiveMainloop::ElementOut; + + struct Arguments { + // _1, max_seqlen_k, head_dim, ((h_g, h_kv), b) + ProblemShapeIn problem_shape; + const int* seqlen_kv; + const int* cache_batch_idx; + + const Element* ptr_q; // 1 x D x (H x B) + StrideQOrig dQ; + const Element* ptr_new_k; // 1 x D x (H x B) + StrideNewK dNewK; + const Element* ptr_new_v; // 1 x D x (H x B) + StrideNewV dNewV; + + Element* ptr_cache_k; // seqlen_max x D x (H x B) + StrideCacheK dCacheK; + Element* ptr_cache_v; // seqlen_max x D x (H x B) + StrideCacheV dCacheV; + ElementOut* ptr_o; // 1 x D x (H x B) + StrideOOrig dO; + + cutlass::KernelHardwareInfo hw_info; + + ElementAcc scale_softmax = 0.0f; + }; + + struct Params { + ProblemShape problem_shape; + const int* seqlen_kv; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return true; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + ProblemShape problem_shape = replace<0>(args.problem_shape, static_cast(get<0>(args.problem_shape))); + CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{}); + StrideQ dQ = replace<0>(args.dQ, 0); + StrideO dO = replace<0>(args.dO, 0); + get<0>(problem_shape) = get<3,0,0>(args.problem_shape); + get<3,0,0>(problem_shape) = 1; + get<0>(dQ) = get<2,0,0>(dQ); + get<0>(dO) = get<2,0,0>(dO); + + typename CollectiveMainloop::Arguments mainloop_args { + { + args.cache_batch_idx, + args.ptr_q, dQ, + args.ptr_new_k, args.dNewK, + args.ptr_new_v, args.dNewV, + args.ptr_cache_k, args.dCacheK, + args.ptr_cache_v, args.dCacheV, + }, + args.scale_softmax + }; + + typename CollectiveEpilogue::Arguments epilogue_args { + args.ptr_o, dO, + }; + + return Params{ + problem_shape, + args.seqlen_kv, + CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace), + TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, TileShape{}) 
+ }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + ProblemShape result = problem_shape; + get<1>(result) = params.seqlen_kv[batch_idx]; + if (params.mainloop.load.ptr_new_k != nullptr) { + get<1>(result) += 1; + } + return result; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = 
CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 
1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue(params.epilogue); + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? 
pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.epilogue, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + shared_storage.epilogue, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + 
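+    // Summary of the producer/consumer pipeline wiring established above:
+    //   Load       --(PipelineQ, PipelineKV)-------> MMA
+    //   MMA        --(PipelineS: mma_s0 / mma_s1)--> Softmax0 / Softmax1
+    //   Softmax0/1 --(PipelineC: s0_corr/s1_corr)--> Correction
+    //   MMA        --(PipelineO: mma_corr)---------> Correction
+    //   Correction --(PipelineE: corr_epi)---------> Epilogue
+    // TMEM is allocated by the MMA warp and freed by the Epilogue warp,
+    // or by the Correction warp when NumWarpsEpilogue == 0.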
} + +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp new file mode 100644 index 0000000000..003fff651e --- /dev/null +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAcc, + class ProblemShape, + class TensorQ, + class TensorNewK, + class TensorNewV, + class TensorCacheK, + class TensorCacheV, + class TensorO +> +void __global__ fmha_fwd_gen_reference_kernel( + ProblemShape problem_shape, + const int* seqlen_kv, const int* cache_batch_idx, + TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV, + TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) { + + using namespace cute; + extern __shared__ char mS_mem[]; + ElementAcc* mS = reinterpret_cast(mS_mem); + + float scale = 1.0f / std::sqrt(float(get<2>(problem_shape))); + + if (mNewK.data() != nullptr) { + // 1. copy in new_k to cache + for (int idx_h = blockIdx.x; idx_h < size<3,0,1>(problem_shape); idx_h += gridDim.x) { + for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) { + int idx_b_kv = cache_batch_idx != nullptr ? 
cache_batch_idx[idx_b] : idx_b; + for (int idx_d = threadIdx.x; idx_d < size<2>(problem_shape); idx_d += blockDim.x) { + mCacheK(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) = + mNewK(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b)); + mCacheV(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) = + mNewV(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b)); + } + } + } + } + + // 2. compute attention + for (int idx_h_kv = blockIdx.x; idx_h_kv < size<3,0,1>(problem_shape); idx_h_kv += gridDim.x) { + for (int idx_h_qo = blockIdx.y; idx_h_qo < size<3,0,0>(problem_shape); idx_h_qo += gridDim.y) { + int idx_h = idx_h_qo + size<3,0,0>(problem_shape) * idx_h_kv; + for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) { + int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b; + const int kDim = 128; + ElementAcc reg_o[kDim] = {0}; + ElementAcc row_max = -INFINITY; + ElementAcc row_sum = 0; + auto iteration = [&](auto const& tK, auto const& tV) { + ElementAcc reg_s = 0; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + ElementAcc eQ = mQ(_0{}, idx_d, make_coord(idx_h, idx_b)); + ElementAcc eK = tK(idx_d); + reg_s += eQ * eK; + } + + ElementAcc old_row_max = row_max; + row_max = std::max(row_max, reg_s); + + ElementAcc adjustment = std::exp(scale * (old_row_max - row_max)); + row_sum *= adjustment; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] *= adjustment; + } + + ElementAcc reg_p = std::exp(scale * (reg_s - row_max)); + row_sum += reg_p; + + for (int idx_d = 0; idx_d < kDim; idx_d++) { + ElementAcc eV = tV(idx_d); + reg_o[idx_d] += reg_p * eV; + } + }; + + for (int idx_s = threadIdx.x; idx_s < seqlen_kv[idx_b]; idx_s += blockDim.x) { + iteration(mCacheK(idx_s, _, make_coord(idx_h, idx_b_kv)), mCacheV(idx_s, _, make_coord(idx_h, idx_b_kv))); + } + + if (mNewK.data() != nullptr && threadIdx.x == 0) { + iteration(mNewK(_0{}, _, make_coord(idx_h, idx_b)), mNewV(_0{}, _, make_coord(idx_h, idx_b))); + } + + mS[threadIdx.x] = row_max; + __syncthreads(); + float old_row_max = row_max; + for (int i = 0; i < blockDim.x; i++) { + row_max = std::max(row_max, mS[i]); + } + __syncthreads(); + + ElementAcc adjustment = std::exp(scale * (old_row_max - row_max)); + row_sum *= adjustment; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] *= adjustment; + } + mS[threadIdx.x] = row_sum; + __syncthreads(); + + row_sum = 0; + for (int i = 0; i < blockDim.x; i++) { + row_sum += mS[i]; + } + __syncthreads(); + for (int idx_d = 0; idx_d < kDim; idx_d++) { + mS[idx_d] = 0; + } + __syncthreads(); + + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] /= row_sum; + atomicAdd(&mS[idx_d], reg_o[idx_d]); + } + + __syncthreads(); + for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) { + +// printf("O[%d,%d,%d] = %f\n", idx_d, idx_h, idx_b, mS[idx_d]); + mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast(mS[idx_d]); + } + } + } + } +} + +template< + class ElementAcc, + class ProblemShape, + class TensorQ, + class TensorNewK, + class TensorNewV, + class TensorCacheK, + class TensorCacheV, + class TensorO +> +void fmha_fwd_gen_reference( + ProblemShape problem_shape, + const int* seqlen_kv, const int* cache_batch_idx, + TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV, + TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) { + + using namespace cute; + + dim3 grid(get<3,0,1>(problem_shape), get<3,0,0>(problem_shape), get<3,1>(problem_shape)); + dim3 block(128); + int 
shared_mem = int(sizeof(ElementAcc)) * std::max(128, block.x); + assert(get<2>(problem_shape) == 128); + fmha_fwd_gen_reference_kernel<<>>( + problem_shape, seqlen_kv, cache_batch_idx, + mQ, mNewK, mNewV, mCacheK, mCacheV, mO + ); +} diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp new file mode 100644 index 0000000000..48d8110187 --- /dev/null +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "collective/fmha_fusion.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeIn, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Mask +> +void __global__ fmha_reference_kernel( + ProblemShapeIn problem_shape_in, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Mask mask) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + ElementAccumulator* mS = reinterpret_cast(mS_mem); + + ElementAccumulator softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + auto id = make_identity_tensor(make_shape(1, 1)); + for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) { + + auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in)); + auto coord_in = cute::make_tuple(idx_Q, _0{}, _0{}, coord_L); + auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in)); + + if (get<0,0>(coord) >= get<0>(problem_shape)) continue; + + int offset_Q = 0; + if constexpr (rank<0>(decltype(coord){}) == 2) { + offset_Q = get<0,1>(coord); + } + + int offset_K = 0; + if constexpr (rank<1>(decltype(coord){}) == 2) { + offset_K = get<1,1>(coord); + } + + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_D = 0; idx_D < size<2>(problem_shape); idx_D++) { + ElementAccumulator eQ = mQ(idx_Q + offset_Q, idx_D, idx_L); + ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L); + acc += eQ * eK; + } + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc; + mask.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + mS[idx_K] = frag(0); + } + + __syncthreads(); + + ElementAccumulator maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAccumulator sum = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + sum += mS[idx_K]; + } + + ElementAccumulator scale = 1.0f / sum; + + for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + ElementAccumulator eV = mV(idx_K + offset_K, idx_D, idx_L); + ElementAccumulator eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(acc * scale); + } + + if (threadIdx.x == 0) { + mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeIn, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Mask +> +void 
fmha_reference( + ProblemShapeIn problem_shape_in, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Mask mask) { + + using namespace cute; + + dim3 grid(size<0>(mO), size<2>(mO), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * int(sizeof(typename TensorLSE::value_type)); + fmha_reference_kernel<<>>(problem_shape_in, mQ, mK, mV, mO, mLSE, mask); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp new file mode 100644 index 0000000000..e4a01c8216 --- /dev/null +++ b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp @@ -0,0 +1,180 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#pragma once + +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DeviceAllocation { + T* ptr_ = nullptr; + size_t offset_ = 0; + size_t size_ = 0; + + DeviceAllocation(DeviceAllocation const&) = delete; + DeviceAllocation& operator=(DeviceAllocation const&) = delete; + + DeviceAllocation() = default; + DeviceAllocation(size_t size) { reset(size); } + ~DeviceAllocation() { reset(); } + + void reset(size_t size, size_t offset=0) { + reset(); + auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); + assert(ret == cudaSuccess); + size_ = size; + offset_ = offset; + } + + T* get() { + return ptr_ + offset_; + } + + const T* get() const { + return ptr_ + offset_; + } + + void reset() { + if (ptr_ != nullptr) { + auto ret = cudaFree(ptr_); + assert(ret == cudaSuccess); + } + } + + size_t size() const { return size_; } + + void copy_from_host(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } + + void copy_from_device(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } +}; + +template +__global__ void reference_abs_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff ) { + + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + double diff = fabs(data[i] - data_ref[i]); + if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } + else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_abs_diff( + DeviceAllocation const& data, + DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff) { + + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. 
Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_abs_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 2524378a62..21166302ab 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -146,6 +146,14 @@ foreach(EXAMPLE 64_ada_fp8_gemm_grouped 65_distributed_gemm 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling + 70_blackwell_gemm + 71_blackwell_gemm_with_collective_builder + 72_blackwell_narrow_precision_gemm + 73_blackwell_gemm_preferred_cluster + 74_blackwell_gemm_streamk + 75_blackwell_grouped_gemm + 76_blackwell_conv + 77_blackwell_fmha ) add_subdirectory(${EXAMPLE}) diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000..dddfa4c3e6 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,293 @@ +# CUTLASS - Programming Examples + +* [00_basic_gemm](00_basic_gemm/) + + launches a basic GEMM with single precision inputs and outputs + +* [01_cutlass_utilities](01_cutlass_utilities/) + + demonstrates CUTLASS Utilities for allocating and initializing tensors + +* [02_dump_reg_smem](02_dump_reg_smem/) + + debugging utilities for printing register and shared memory contents + +* [03_visualize_layout](03_visualize_layout/) + + utility for visualizing all layout functions in CUTLASS + +* [04_tile_iterator](04_tile_iterator/) + + example demonstrating an iterator over tiles in memory + +* [05_batched_gemm](05_batched_gemm/) + + example demonstrating CUTLASS's batched strided GEMM operation + +* [06_splitK_gemm](06_splitK_gemm/) + + example demonstrating CUTLASS's Split-K parallel reduction kernel + +* [07_volta_tensorop_gemm](07_volta_tensorop_gemm/) + + example demonstrating mixed precision GEMM using Volta Tensor Cores + +* [08_turing_tensorop_gemm](08_turing_tensorop_gemm/) + + example demonstrating integer GEMM using Turing Tensor Cores + +* [09_turing_tensorop_conv2dfprop](09_turing_tensorop_conv2dfprop/) + + example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores + +* [10_planar_complex](10_planar_complex/) + + example demonstrating planar complex GEMM kernels + +* [11_planar_complex_array](11_planar_complex_array/) + + example demonstrating planar complex kernels with batch-specific problem sizes + +* [12_gemm_bias_relu](12_gemm_bias_relu/) + + example demonstrating GEMM fused with bias and relu + +* [13_two_tensor_op_fusion](13_two_tensor_op_fusion/) + + example demonstrating two GEMMs or convolutions fused in one kernel + +* [14_ampere_tf32_tensorop_gemm](14_ampere_tf32_tensorop_gemm/) + + example demonstrating FP32 GEMM with implicit TF32 conversion + +* [15_ampere_sparse_tensorop_gemm](15_ampere_sparse_tensorop_gemm/) + + example demonstrating usage of Sparse 
Tensor cores + +* [16_ampere_tensorop_conv2dfprop](16_ampere_tensorop_conv2dfprop/) + + example demonstrating forward convolution on tensors of layout NHWC + +* [17_fprop_per_channel_bias](17_fprop_per_channel_bias/) + + example demonstrating convolution fused with per channel bias and relu + +* [18_ampere_fp64_tensorop_affine2_gemm](18_ampere_fp64_tensorop_affine2_gemm/) + + example demonstrating Affine-2 GEMM + +* [19_tensorop_canonical](19_tensorop_canonical/) + + Canonical GEMM using tensor cores + +* [20_simt_canonical](20_simt_canonical/) + + Canonical GEMM using SIMT + +* [21_quaternion_gemm](21_quaternion_gemm/) + + example demonstrating Quaternion GEMM computations + +* [22_quaternion conv](22_quaternion_conv/) + + example demonstrating Quaternion convolution + +* [23_ampere_gemm_operand_reduction_fusion](23_ampere_gemm_operand_reduction_fusion/) + + example demonstrating how to reduce one of the operands of the GEMM along the k-dimension when computing GEMM + +* [24_gemm_grouped](24_gemm_grouped/) + + example demonstrating batch of GEMM operations with distinct problem sizes + +* [25_ampere_fprop_mainloop_fusion](25_ampere_fprop_mainloop_fusion/) + + example demonstrating fusing activation's per channel scale+bias+relu into the fgrad mainloop + +* [26_ampere_wgrad_mainloop_fusion](26_ampere_wgrad_mainloop_fusion/) + + example demonstrating fusing activation's per channel scale+bias+relu into the wgrad mainloop + +* [27_ampere_3xtf32_fast_accurate_tensorop_gemm](27_ampere_3xtf32_fast_accurate_tensorop_gemm/) + + example demonstrating emulation of a fast accurate SGEMM with TF32 operations + +* [28_ampere_3xtf32_fast_accurate_tensorop_fprop](28_ampere_3xtf32_fast_accurate_tensorop_fprop/) + + example demonstrating emulation of a fast accurate FP32 convolution with TF32 operation + +* [29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm](29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/) + + example demonstrating emulation of a fast accurate CGEMM with TF32 operation + +* [30_wgrad_split_k](30_wgrad_split_k/) + + example demonstrating how to compute conv2d gradient with respect to weight (wgrad) together with split-K + +* [31_basic_syrk](31_basic_syrk/) + + example demonstrating Symmetric Rank-K update + +* [32_basic_trmm](32_basic_trmm/) + + example demonstrating Triangular Matrix-Matrix multiplication + +* [33_ampere_3xtf32_tensorop_symm](33_ampere_3xtf32_tensorop_symm/) + + example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation + +* [34_transposed_conv2d](34_transposed_conv2d/) + + example demonstrating how to compute 2d transposed convolution, also known as deconvolution, using CUTLASS conv2d Dgrad kernels + +* [35_gemm_softmax](35_gemm_softmax/) + + example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores + +* [36_gather_scatter_fusion](36_gather_scatter_fusion/) + + example demonstrating fuses gather before GEMM and scatter after GEMM into the same GEMM kernel + +* [37_gemm_layernorm_gemm_fusion](37_gemm_layernorm_gemm_fusion/) + + example demonstrating fuses gemm->layernorm->gemm into one kernel. 
+ +* [38_syr2k_grouped](38_syr2k_grouped/) + + example demonstrating a batch of SYR2K operations with distinct problem sizes + +* [39_gemm_permute](39_gemm_permute/) + + example demonstrating batched GEMM operations with output results permuted as reshaped tensors + +* [40_cutlass_py](40_cutlass_py/) + + example demonstrating CUTLASS with Python interface + +* [41_multi_head_attention](41_multi_head_attention/) + + example demonstrating attention example with non-fixed sequence length input + +* [42_ampere_tensorop_group_conv](42_ampere_tensorop_group_conv/) + + example demonstrating how to run group convolution kernels using functions and data structures provided by CUTLASS using tensor cores + +* [43_ell_block_sparse_gemm](43_ell_block_sparse_gemm/) + + example demonstrating a Block-Ell sparse gemm + +* [44_fused_multi_head_attention](44_fused_multi_head_attention/) + + example demonstrating fused multihead attention (fixed & variable) using shared memory + +* [45_dual_gemm](45_dual_gemm/) + + example demonstrating how to fuse two GEMMs sharing the same left input matrix into one kernel + +* [46_depthwise_simt_conv2dfprop](46_depthwise_simt_conv2dfprop/) + + example demonstrating depthwise 2d convolution kernels using functions and data structures provided by CUTLASS using SIMT instruction + +* [47_ampere_gemm_universal_streamk](47_ampere_gemm_universal_streamk/) + + example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the + "classic data-parallel" and "Split-K" decompositions. + +* [48_hopper_warp_specialized_gemm](48_hopper_warp_specialized_gemm/) + + Simple tensorop GEMM example using CUTLASS 3.0 APIs targeting NVIDIA Hopper architecture + +* [49_hopper_gemm_schedules_with_collective_builder](49_hopper_gemm_schedules_with_collective_builder/) + + Hopper GEMM example leveraging collective operation builders to showcase the builder API and the various kernel scheduled supported in CUTLASS 3.0 such as warp specialized persistent mainloops. + +* [50_hopper_gemm_with_epilogue_swizzle](50_hopper_gemm_with_epilogue_swizzle/) + + Hopper GEMM example to create a GEMM kernel with custom a collective mainloop and a custom vectorized epilogue. + +* [51_hopper_gett](51_hopper_gett/) + + Hopper GETT example illustrating the ease with which GETTs can be run due to CUTLASS 3.0's unified micro-kernels and CuTe's hierarchical layouts. + +* [52_hopper_gather_scatter_fusion](52_hopper_gather_scatter_fusion/) + + Hopper example that fuses gather before GEMM and scatter after GEMM into the same kernel + +* [53_hopper_gemm_permute](53_hopper_gemm_permute/) + + Hopper example demonstrating the fusion of tensor permutation operations with a GEMM kernel + +* [54_hopper_fp8_warp_specialized_gemm](54_hopper_fp8_warp_specialized_gemm/) + + Hopper example of instantiating and running an FP8 GEMM kernel + +* [55_hopper_mixed_dtype_gemm](55_hopper_mixed_dtype_gemm/) + + Hopper GEMM example with different A and B data types using CUTLASS 3.x APIs for DL kernels with fused dequantization. + +* [56_hopper_ptr_array_batched_gemm](56_hopper_ptr_array_batched_gemm/) + + Hopper Ptr-Array Batched GEMM example using CUTLASS 3.x API. + +* [57_hopper_grouped_gemm](57_hopper_grouped_gemm/) + + Hopper Grouped GEMM using CUTLASS 3.x API. + +* [58_ada_fp8_gemm](58_ada_fp8_gemm/) + + Ada GEMM kernel targetting Ada FP8 tensor cores via the CUTLASS 2.x API. 
+ +* [59_ampere_gather_scatter_conv](59_ampere_gather_scatter_conv/) + + CuTe and CUTLASS 3.x based Ampere convolution fprop kernel capable of operating on both affine and gather/scatter tensors, + showing how kernel authors can re-use CUTLASS 3.x collectives in their custom kernels. + +* [61_hopper_gemm_with_topk_and_softmax](61_hopper_gemm_with_topk_and_softmax/) + + Hopper GEMM kernel with Top-K and softmax epilogue fusion. + +[//]: # + +* [70_blackwell_gemm](70_blackwell_gemm) + + Simple dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. + +* [71_blackwell_gemm_with_collective_builder](71_blackwell_gemm_with_collective_builder) + + Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules and epilogue visitor tree (EVT) construction + +* [72a_blackwell_narrow_precision_gemm](72a_blackwell_narrow_precision_gemm) + + Block-scaled dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. + +* [73_blackwell_gemm_preferred_cluster](73_blackwell_gemm_preferred_cluster/) + + Blackwell SM100 GEMM kernel with preferred cluster feature. + +* [74_blackwell_gemm_streamk](74_blackwell_gemm_streamk/) + + Blackwell SM100 GEMM kernel using the Stream-K scheduler + +* [75_blackwell_grouped_gemm](75_blackwell_grouped_gemm) + + Blackwell SM100 grouped GEMM kernel + +* [76_blackwell_conv](76_blackwell_conv/) + + Simple convolution(fprop/dgrad/wgrad) example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. + +* [77_blackwell_fmha](77_blackwell_fmha) + + Blackwell SM100 FMHA kernel + +[//]: # + +# CuTe - Programming Examples + +Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/). + +Additionally, CuTe's core layout and layout algebra have their own test cases within [cutlass/test/unit/cute/core/](../test/unit/cute/core/) that users might find useful as examples of CuTe. + +# Python Interface Examples + +Examples leveraging CUTLASS's [Python interface](../python/README.md) are located in [cutlass/examples/python](python/). diff --git a/include/cute/arch/cluster_sm100.hpp b/include/cute/arch/cluster_sm100.hpp new file mode 100755 index 0000000000..0bcf19b0f2 --- /dev/null +++ b/include/cute/arch/cluster_sm100.hpp @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// + +#pragma once + +#include + + +namespace cute { + + +// +// Cluster launch utility +// +CUTE_HOST +bool +initialize_preferred_cluster_launch(void const* const kernel_function, + dim3 const& grid_dims, + dim3 const& cluster_dims_preferred, + dim3 const& cluster_dims_fallback) +{ + // + // Validate cluster_dims + // + + // Total number of cluster cannot be greater than 32 (hardware requirement) + if (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z <= 0 || + cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z > 32) { + std::cout << "Invalid preferred cluster dimensions: Attempting to init preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z + << ") [" << (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z) << "] which must be within (0,32]." << std::endl; + return false; + } + + // Total number of cluster cannot be greater than 32 (hardware requirement) + if (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z <= 0 || + cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z > 32) { + std::cout << "Invalid cluster dimensions: Attempting to init fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") [" << (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z) << "] which must be within (0,32]." << std::endl; + return false; + } + + // Total grid dimensions must be within (2^32, 2^16, 2^16) + if (grid_dims.y > (1 << 16) || grid_dims.z > (1 << 16)) { + std::cout << "Invalid grid dimensions: Attempting to init grid dimensions (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z + << ") which must be within (2^32, 2^16, 2^16)." << std::endl; + return false; + } + + // grid_dims should be divisible by cluster_dims_preferred + if (grid_dims.x % cluster_dims_preferred.x != 0 || + grid_dims.y % cluster_dims_preferred.y != 0 || + grid_dims.z % cluster_dims_preferred.z != 0) { + std::cout << "Invalid grid dimensions: Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z + << ") does not divide Grid (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z << ")." 
<< std::endl; + return false; + } + + // cluster_dims_preferred should be divisible by cluster_dims_fallback + if (cluster_dims_preferred.x % cluster_dims_fallback.x != 0 || + cluster_dims_preferred.y % cluster_dims_fallback.y != 0 || + cluster_dims_preferred.z % cluster_dims_fallback.z != 0) { + std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") does not divide Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ")." << std::endl; + return false; + } + + // Both cluster dimenions should have the same depth + if (cluster_dims_preferred.z != cluster_dims_fallback.z) { + std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") and Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ") does not have the same depth." << std::endl; + return false; + } + + return true; +} +} // end namespace cute diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index 4af01e339f..a81b4e33e3 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -48,3 +48,42 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) + #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + + + diff --git a/include/cute/arch/copy_sm100.hpp b/include/cute/arch/copy_sm100.hpp new file mode 100644 index 0000000000..19b13841a1 --- /dev/null +++ b/include/cute/arch/copy_sm100.hpp @@ -0,0 +1,7567 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// +#pragma once + +#include + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// LDSM PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t tmp0, tmp1; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8 {%0, %1}, [%2];\n" + : "=r"(reinterpret_cast(tmp0)), "=r"(reinterpret_cast(tmp1)) + : "r"(smem_int_ptr)); + // RefLayout of ldmatrix.m16n16.x1.trans won't match stmatrix.m16n8.x2.trans without additional transformations + // Do this here so we don't need to add an additional reg to reg copy at the collective layer + uchar4& tmp0_ = reinterpret_cast(tmp0); + uchar4& tmp1_ = reinterpret_cast(tmp1); + uchar4 dst0_{tmp0_.x, tmp0_.y, tmp1_.x, tmp1_.y}; + uchar4 dst1_{tmp0_.z, tmp0_.w, tmp1_.z, tmp1_.w}; + dst0 = reinterpret_cast(dst0_); + dst1 = reinterpret_cast(dst1_); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x16_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t tmp0, tmp1, tmp2, tmp3; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(tmp0)), "=r"(reinterpret_cast(tmp1)), + "=r"(reinterpret_cast(tmp2)), "=r"(reinterpret_cast(tmp3)) + : "r"(smem_int_ptr)); + uchar4& tmp0_ = reinterpret_cast(tmp0); + uchar4& tmp1_ = reinterpret_cast(tmp1); + uchar4& tmp2_ = 
reinterpret_cast(tmp2); + uchar4& tmp3_ = reinterpret_cast(tmp3); + uchar4 dst0_{tmp0_.x, tmp0_.y, tmp1_.x, tmp1_.y}; + uchar4 dst1_{tmp0_.z, tmp0_.w, tmp1_.z, tmp1_.w}; + uchar4 dst2_{tmp2_.x, tmp2_.y, tmp3_.x, tmp3_.y}; + uchar4 dst3_{tmp2_.z, tmp2_.w, tmp3_.z, tmp3_.w}; + dst0 = reinterpret_cast(dst0_); + dst1 = reinterpret_cast(dst1_); + dst2 = reinterpret_cast(dst2_); + dst3 = reinterpret_cast(dst3_); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +struct SM100_SU4_DU8x16_x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b6x16_p32 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU4_DU8x16_x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b6x16_p32 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU4_DU8x16_x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile 
("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b6x16_p32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// STSM PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x4_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x8_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x16_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// UTCCP PTX definitions +// 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cute {
+// 128 data path lanes, 256-bit pattern, 1cta mode
+struct SM100_UTCCP_128dp256bit_1cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::1.128x256b [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 128 data path lanes, 256-bit pattern, 2cta mode
+struct SM100_UTCCP_128dp256bit_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.128x256b [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+struct SM100_UTCCP_128dp128bit_1cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::1.128x128b [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+struct SM100_UTCCP_128dp128bit_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.128x128b [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+
+// 4 data path lanes, 256-bit pattern, 1cta mode
+struct SM100_UTCCP_4dp256bit_1cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::1.4x256b [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 4 data path lanes, 256-bit pattern, 2cta mode
+struct SM100_UTCCP_4dp256bit_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.4x256b [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 4x32 data path lanes (broadcast), 128-bit pattern, 1cta mode
+struct SM100_UTCCP_4x32dp128bit_1cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::1.32x128b.warpx4 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 4x32 data path lanes (broadcast), 128-bit pattern, 2cta mode
+struct SM100_UTCCP_4x32dp128bit_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.32x128b.warpx4 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 2x64 data path lanes (broadcast like 4x32dp), 128-bit pattern, 1cta mode
+struct SM100_UTCCP_2x64dp128bitlw0213_1cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::1.64x128b.warpx2::02_13 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 2x64 data path lanes (broadcast like 4x32dp), 128-bit pattern, 2cta mode
+struct SM100_UTCCP_2x64dp128bitlw0213_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.64x128b.warpx2::02_13 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 2x64 data path lanes (broadcast separately in upper and lower 64dp), 128-bit pattern, 1cta mode
+// data_row[0:31] -> DP[0:63]
+// data_row[32:63] -> DP[64:127]
+struct SM100_UTCCP_2x64dp128bitlw0123_1cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::1.64x128b.warpx2::01_23 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+// 2x64 data path lanes (broadcast separately in upper and lower 64dp), 128-bit pattern, 2cta mode
+// data_row[0:31] -> DP[0:63]
+// data_row[32:63] -> DP[64:127]
+struct SM100_UTCCP_2x64dp128bitlw0123_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.64x128b.warpx2::01_23 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cute
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cute {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
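+// Usage sketch for the tcgen05.cp (UTCCP) atoms defined above. This is illustrative only: it
+// assumes a 64-bit shared-memory source operand (smem_desc) and an encoded tensor-memory
+// destination address (tmem_addr) have already been produced elsewhere; neither is constructed
+// in this header.
+//
+//   uint64_t smem_desc = /* 64-bit SMEM source operand of tcgen05.cp */;
+//   uint32_t tmem_addr = /* destination TMEM address                 */;
+//   SM100_UTCCP_128dp256bit_1cta::copy(smem_desc, tmem_addr);
+//
+// The variants differ only in the datapath-lane count, the per-lane bit pattern, and the
+// cta_group (::1 vs. ::2) of the tcgen05.cp instruction they emit.
+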
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_LOAD PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp256b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x1.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp256b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x2.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + 
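+// Usage sketch for the TMEM_LOAD atoms in this section (illustrative; tmem_addr is assumed to be
+// an already-encoded tensor-memory address). Each struct wraps a single tcgen05.ld shape and
+// returns the loaded data in per-thread uint32_t registers matching its DRegisters declaration:
+//
+//   uint32_t tmem_addr = /* encoded TMEM address */;
+//   uint32_t r0, r1, r2, r3;
+//   SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, r0, r1, r2, r3);
+//
+// The *_16b variants issue the same shape with the pack::16b qualifier, which packs pairs of
+// 16-bit elements into each 32-bit destination register.
+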
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp256b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x4.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp256b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + 
"%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp256b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& 
dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + 
"%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp256b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& 
dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b32x_16b +{ + using 
SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), 
"=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp128b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x1.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp128b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& 
src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp128b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x4.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x4.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp128b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + 
{ +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp128b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + 
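+// Note on the naming scheme of these atoms, as reflected in the shapes above: an atom named
+// SM100_TMEM_LOAD_<N>dp<B>b<R>x reads <N> datapath lanes of a <B>-bit pattern repeated <R> times,
+// so the per-thread destination register count across a 32-thread warp is (N * B * R) / (32 * 32).
+// For example, 16dp128b1x declares uint32_t[2] and 16dp128b64x declares uint32_t[128], consistent
+// with the DRegisters definitions in this file.
+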
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp128b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + 
"%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), 
"=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp128b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + 
"%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, 
uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), 
"=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp64b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x1.pack::16b.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp64b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t 
const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x2.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp64b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp64b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD 
without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp64b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp64b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, 
%18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp64b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& 
dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," 
+ "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_16dp64b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, 
uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times, packed 16b read 
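+//
+// [Editorial note, not part of the generated atom definitions] The atom names encode the
+// tcgen05.ld shape directly: SM100_TMEM_LOAD_<L>dp<B>b<N>x[_16b] wraps
+// "tcgen05.ld.sync.aligned.<L>x<B>b.x<N>[.pack::16b].b32", i.e. <L> TMEM data path lanes with a
+// <B>-bit pattern per lane, repeated <N> times; the _16b suffix selects the pack::16b qualifier
+// (the "packed 16b read" above), which we understand to pack two 16-bit elements into each
+// 32-bit destination register. The per-thread register count follows from the shape: L * B * N
+// bits spread across the 32 threads of a warp, e.g. for the struct below
+// 16 * 64 * 128 / (32 * 32) = 128, matching DRegisters = uint32_t[128]. The 16dp32b atoms
+// further down map onto the 16x32bx2 shape, whose trailing immediate operand (1, 2, 4, ...)
+// we understand to be the half-split column offset of the second 16x32b access.
+//
+// Minimal usage sketch (illustrative only; tmem_addr is assumed to come from the TMEM
+// allocator, and real kernels would normally wrap these atoms in a cute::Copy_Atom rather
+// than invoking them directly):
+//
+//   uint32_t r0;                                     // one b32 destination per thread
+//   SM100_TMEM_LOAD_16dp64b1x::copy(tmem_addr, r0);  // 16x64b.x1 -> 1 register per thread
+//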
+struct SM100_TMEM_LOAD_16dp64b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), 
"=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x1.b32" + "{%0}," + "[%1], 1;\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x1.pack::16b.b32" + "{%0}," + "[%1], 2;\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp32b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& 
src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x2.b32" + "{%0, %1}," + "[%2], 2;\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x2.pack::16b.b32" + "{%0, %1}," + "[%2], 4;\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp32b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x4.b32" + "{%0, %1, %2, %3}," + "[%4], 4;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4], 8;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp32b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8], 8;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit 
pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8], 16;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp32b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16], 16;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16], 32;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& 
dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32], 32;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32], 64;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp32b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& 
dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64], 64;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, 
uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64], 128;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_16dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& 
dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128], 128;\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), 
"=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," 
+ "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128], 256;\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_32dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b1x_16b +{ + using SRegisters = 
uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x1.pack::16b.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_32dp32b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x2.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_32dp32b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_32dp32b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + 
CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_32dp32b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), 
"=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_32dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + 
"=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_32dp32b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b64x_16b +{ + using SRegisters = 
uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_32dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, 
uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), 
"=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& 
dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
+// +// TMEM_STORE PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp256b1x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b1x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp256b2x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b2x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + 
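// A minimal usage sketch, assuming `tmem_addr` is a valid, suitably aligned tensor-memory
// address obtained elsewhere (e.g. from the SM100 TMEM allocator); it shows only the raw
// calling convention of one store and one load wrapper defined above, with an illustrative
// function name. In real kernels these structs are driven through CuTe Copy_Atom/TiledCopy
// partitioning rather than called by hand, and a tcgen05 wait/fence (omitted here) must
// order the store against any later load of the same region.
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
__device__ void tmem_store_load_sketch(uint32_t tmem_addr)
{
  // Stage four 32-bit registers into one 16x256b column of tensor memory.
  uint32_t s0 = 0u, s1 = 1u, s2 = 2u, s3 = 3u;
  SM100_TMEM_STORE_16dp256b1x::copy(s0, s1, s2, s3, tmem_addr);

  // Read one 32-bit element per thread back through the 32x32b.x1 load shape.
  uint32_t d0 = 0u;
  SM100_TMEM_LOAD_32dp32b1x::copy(tmem_addr, d0);
}
#endif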
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp256b4x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x4.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b4x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp256b8x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, 
uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b8x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp256b16x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, 
uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b16x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t 
const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp256b32x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, 
uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), 
"r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b32x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t 
const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), 
"r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp128b1x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b1x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp128b2x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b2x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), 
"r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp128b4x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b4x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp128b8x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b8x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = 
uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp128b16x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b16x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, 
uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp128b32x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + 
"%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b32x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), 
"r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp128b64x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + 
uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without 
CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b64x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { 
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp64b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path 
lanes, 64-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.unpack::16b.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp64b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp64b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + 
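Each SM100_TMEM_STORE_* atom above wraps one warp-synchronous tcgen05.st instruction: SRegisters gives the number of 32-bit source registers each thread supplies, DRegisters is the single TMEM destination address, and the _16b variants add the unpack::16b qualifier (the "expand 16b write" forms). The sketch below is illustrative only, not part of the header; it assumes these structs live in namespace cute, that the include path cute/arch/copy_sm100.hpp resolves against the CUTLASS include directory, that tmem_addr is a valid tensor-memory address obtained elsewhere (e.g. from the SM100 TMEM allocator), and that the translation unit targets an architecture where CUTE_ARCH_TCGEN05_TMEM_ENABLED is defined.

#include <cstdint>
#include <cute/arch/copy_sm100.hpp>   // assumed include path for this header

// Hypothetical helper, not part of the patch: each thread stores four 32-bit
// registers through the 16dp64b4x atom (16 data path lanes, 64-bit pattern
// repeated 4 times -> 4 source registers per thread, one TMEM address).
__device__ void example_tmem_store_16dp64b4x(uint32_t tmem_addr,
                                             uint32_t r0, uint32_t r1,
                                             uint32_t r2, uint32_t r3)
{
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
  // Operand convention matches the struct's copy(): source registers first,
  // TMEM destination address last.
  cute::SM100_TMEM_STORE_16dp64b4x::copy(r0, r1, r2, r3, tmem_addr);
#endif
}

In CuTe these atoms are normally dispatched through the Copy_Atom / copy_traits machinery rather than called by hand; the direct call above is only meant to make the operand convention concrete.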
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp64b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp64b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t 
const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp64b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, 
uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp64b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : 
+ : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), 
"r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_16dp64b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t 
const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x128.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit 
pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x128.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," 
+ "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.b32" + "[%0] , 1," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + 
CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.unpack::16b.b32" + "[%0] , 2," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp32b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.b32" + "[%0] , 2," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.unpack::16b.b32" + "[%0] , 4," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp32b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.b32" + "[%0] , 4," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.unpack::16b.b32" + "[%0] , 8," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp32b8x +{ + using 
SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.b32" + "[%0] , 8," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.unpack::16b.b32" + "[%0] , 16," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp32b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x16.b32" + "[%0] , 16," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t 
const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x16.unpack::16b.b32" + "[%0] , 32," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp32b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x32.b32" + "[%0] , 32," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t 
const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x32.unpack::16b.b32" + "[%0] , 64," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp32b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x64.b32" + "[%0] , 64," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), 
"r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x64.unpack::16b.b32" + "[%0] , 128," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), 
"r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_16dp32b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t 
const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x128.b32" + "[%0] , 128," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = 
uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x128.unpack::16b.b32" + "[%0] , 256," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, 
%36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_32dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile 
("tcgen05.st.sync.aligned.32x32b.x1.unpack::16b.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_32dp32b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_32dp32b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_32dp32b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& 
src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_32dp32b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile 
("tcgen05.st.sync.aligned.32x32b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_32dp32b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if 
defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_32dp32b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), 
"r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), 
"r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_32dp32b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, 
uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& 
src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x128.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + 
"%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cute/arch/copy_sm100_tma.hpp b/include/cute/arch/copy_sm100_tma.hpp new file mode 100644 index 0000000000..f69cbffdea --- /dev/null +++ b/include/cute/arch/copy_sm100_tma.hpp @@ -0,0 +1,664 @@ +/*************************************************************************************************** + * Copyright (c) 2020 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include + +#include +#include +namespace cute +{ + +constexpr uint32_t Sm100MmaPeerBitMask = 0xFEFFFFFF; +constexpr uint64_t Sm100MemDescDefault = uint64_t(0x1000000000000000); + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// UTMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
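+ // Arithmetic note (the same pattern is used by all SM100_TMA_2SM_* ops in this file): + // Sm100MmaPeerBitMask is 0xFEFFFFFF, i.e. ~0x01000000, so the AND below clears bit 24 of the + // barrier's shared-memory address, the "peer bit" referred to in the comment above. For + // example, 0x01000200 & Sm100MmaPeerBitMask == 0x00000200, so the transaction bytes of both + // CTAs' loads are counted on the same (CTA0) barrier.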
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM100_TMA_2SM_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
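+ // The multicast_mask argument feeds the ctaMask operand (the "h" constraint below) of the + // .multicast::cluster form; under the usual PTX semantics for cp.async.bulk.tensor, bit i of + // the 16-bit mask selects the CTA with rank i in the cluster as a destination, e.g. a mask of + // 0x3 delivers the same tile to cluster ranks 0 and 1.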
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) +uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM100_TMA_2SM_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +};
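+// Usage sketch (illustrative only, with placeholder names; not part of this header). In CUTLASS
+// these ops are normally reached through the CuTe copy-atom / make_tma_copy machinery rather than
+// called directly, but a direct device-side call would follow the overloads above, assuming a
+// prebuilt TMA descriptor, an initialized shared-memory mbarrier, and a destination SMEM buffer:
+//
+//   __device__ void example_2sm_multicast_load(void const* tma_desc,   // prebuilt TMA descriptor
+//                                              uint64_t*   smem_mbar,  // mbarrier in shared memory
+//                                              void*       smem_buf,   // destination SMEM tile
+//                                              int32_t     crd0,
+//                                              int32_t     crd1) {
+//     uint16_t cta_mask   = 0x3;                   // multicast to cluster ranks 0 and 1
+//     uint64_t cache_hint = 0x1000000000000000ull; // CacheHintSm100::EVICT_NORMAL
+//     SM100_TMA_2SM_LOAD_MULTICAST::copy(tma_desc, smem_mbar, cta_mask, cache_hint,
+//                                        smem_buf, crd0, crd1);
+//   }
+//
+// The caller is still responsible for arming the mbarrier with the expected transaction-byte
+// count beforehand, and, as the comments in each op note, the copy is executed by both CTAs of
+// the pair.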
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +struct SM100_TMA_2SM_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
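+ // Operand note (an interpretation of the PTX im2col form used below): {coord_c, coord_w, + // coord_h, coord_d, coord_n} give the starting tensor coordinates, while the braced + // {offset_w, offset_h, offset_d} group supplies the im2col offsets inside the bounding box + // described by the TMA descriptor (e.g. the filter-tap position of an implicit-GEMM + // convolution); the trailing Sm100MemDescDefault operand fills the L2 cache-hint slot and + // equals CacheHintSm100::EVICT_NORMAL.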
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM100_TMA_2SM_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { + return SM100_TMA_2SM_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { + return SM100_TMA_2SM_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9, %10;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11, %12;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +} // end namespace cute diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 040c9f4d36..a11470d8dd 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -140,6 +140,11 @@ enum class SmemSwizzleBits : uint8_t { enum class SmemSwizzleBase : uint8_t { SWIZZLE_BASE_16B = 0, + + SWIZZLE_BASE_32B = 1, + SWIZZLE_BASE_32B_FLIP_8B = 2, + SWIZZLE_BASE_64B = 3, + }; enum class OOBFill : uint8_t { @@ -184,6 +189,14 @@ enum class CacheHintSm90 : uint64_t { EVICT_LAST = 0x14F0000000000000, }; + +enum class CacheHintSm100 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + + #if (__CUDACC_VER_MAJOR__ >= 12) #if !defined(__CUDACC_RTC__) @@ -195,6 +208,7 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; 
} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else @@ -205,6 +219,18 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } } @@ -221,9 +247,21 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { case SmemSwizzleBits::B64: assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); return CU_TENSOR_MAP_SWIZZLE_64B; + #if (0) case SmemSwizzleBits::B128: assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits."); return CU_TENSOR_MAP_SWIZZLE_128B; + + #else + case SmemSwizzleBits::B128: + switch (b) { + default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B; + case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; + case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B; + } + #endif + } } diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index 60f320e3a7..a4bc379472 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -1157,6 +1157,17 @@ tma_store_arrive() { #endif } + +CUTE_HOST_DEVICE static void +tma_desc_commit_group() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + // Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete template CUTE_HOST_DEVICE static void @@ -1173,6 +1184,22 @@ tma_store_wait() { #endif } + +// Wait until all TMA descriptor previously issued are safe to be modified after tma_desc_commit_group() +CUTE_HOST_DEVICE static void +tma_desc_wait_group() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + //////////////////////////////////////////////////////////////////////////////////////////////////// /// TMA_REDUCE_ADD : Initiates a TMA reduce-add from shared memory to global memory //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/mma_sm100.hpp b/include/cute/arch/mma_sm100.hpp new file mode 100644 
index 0000000000..2fa532d2ef --- /dev/null +++ b/include/cute/arch/mma_sm100.hpp @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#include +#include + +namespace cute { + +} // namespace cute diff --git a/include/cute/arch/mma_sm100_desc.hpp b/include/cute/arch/mma_sm100_desc.hpp new file mode 100644 index 0000000000..57934f240f --- /dev/null +++ b/include/cute/arch/mma_sm100_desc.hpp @@ -0,0 +1,652 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include + +#include + +#include +#include // cute::array + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// UMMA Descriptor and utilities + +// UMMA enums and utilities +namespace UMMA +{ + +enum class Major : uint8_t { + K = 0, + MN = 1 +}; + +enum class ScaleIn : uint8_t { + One = 0, + Neg = 1 +}; + +enum class ScaleOut : uint8_t { + Zero = 0, + One = 1 +}; + +enum class Saturate : uint8_t { + False = 0, + True = 1 +}; + +enum class LayoutType : uint8_t { + SWIZZLE_NONE = 0, + SWIZZLE_128B_BASE32B = 1, + SWIZZLE_128B = 2, + SWIZZLE_64B = 4, + SWIZZLE_32B = 6 +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::SWIZZLE_NONE: return "SWIZZLE_NONE"; + case LayoutType::SWIZZLE_128B_BASE32B: return "SWIZZLE_128B_BASE32B"; + case LayoutType::SWIZZLE_128B: return "SWIZZLE_128B"; + case LayoutType::SWIZZLE_64B: return "SWIZZLE_64B"; + case LayoutType::SWIZZLE_32B: return "SWIZZLE_32B"; + } + return nullptr; +} + +union SmemDescriptor +{ + uint64_t desc_ = 0; + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + uint16_t stride_byte_offset_ : 14, version_ : 2; // 14 bits [0,14), 2 bits [14,16) + // base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53). 
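+ // (Encoding example for the 14-bit fields above: the 4 LSBs are dropped, so a 128-byte + // leading-dimension offset is stored as 128 >> 4 == 8, and only 16-byte-aligned values are + // representable.)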
+ uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused + // layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0, SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4, SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5, N/A = 7 + uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8) + }; + // Seperate the field, as we may only update one part of desc + struct { + uint32_t lo; + uint32_t hi; + }; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } +}; + +enum class F16F32Format : uint8_t { + F16 = 0, + BF16 = 1, + TF32 = 2, +}; + +CUTE_HOST_DEVICE char const* to_string(F16F32Format const& t) { + switch (t) { + case F16F32Format::F16: return "F16"; + case F16F32Format::BF16: return "BF16"; + case F16F32Format::TF32: return "TF32"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr F16F32Format to_F16F32Format() { + if constexpr (is_same_v) { return F16F32Format::F16; } else + if constexpr (is_same_v) { return F16F32Format::BF16; } else + if constexpr (is_same_v) { return F16F32Format::TF32; } else + { static_assert(sizeof(T) == 0, "Unknown type for F16F32Format"); } +} + +enum class S8Format : uint8_t { + UINT8 = 0, + INT8 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(S8Format const& t) { + switch (t) { + case S8Format::UINT8: return "UINT8"; + case S8Format::INT8: return "INT8"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr S8Format to_S8Format() { + if constexpr (is_same_v) { return S8Format::UINT8; } else + if constexpr (is_same_v) { return S8Format::INT8; } else + { static_assert(sizeof(T) == 0, "Unknown type for S8Format"); } +} + +enum class MXF8F6F4Format : uint8_t { + E4M3 = 0, + E5M2 = 1, + E2M3 = 3, + E3M2 = 4, + E2M1 = 5, + INVALID = 7 // an invalid datatype for runtime proxy type +}; + +CUTE_HOST_DEVICE char const* to_string(MXF8F6F4Format const& t) { + switch (t) { + case MXF8F6F4Format::E4M3: return "E4M3"; + case MXF8F6F4Format::E5M2: return "E5M2"; + case MXF8F6F4Format::E2M3: return "E2M3"; + case MXF8F6F4Format::E3M2: return "E3M2"; + case MXF8F6F4Format::E2M1: return "E2M1"; + case MXF8F6F4Format::INVALID: return "INVALID"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr MXF8F6F4Format to_MXF8F6F4Format() { + if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for MXF8F6F4Format"); } +} + +enum class MXF4Format : uint8_t { + E2M1 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(MXF4Format const& t) { + switch (t) { + case MXF4Format::E2M1: return "E2M1"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr MXF4Format to_MXF4Format() { + if constexpr (is_same_v) { return MXF4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for MXF4Format"); } +} + +enum class ScaleFormat : uint8_t { + UE4M3 = 0, + UE8M0 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(ScaleFormat const& t) { + switch (t) { + case ScaleFormat::UE4M3: return "UE4M3"; + case ScaleFormat::UE8M0: return "UE8M0"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr ScaleFormat to_ScaleFormat() { + if constexpr (is_same_v) { return 
ScaleFormat::UE4M3; } else + if constexpr (is_same_v) { return ScaleFormat::UE8M0; } else + { static_assert(sizeof(T) == 0, "Unknown type for ScaleFormat"); } +} + +enum class CFormat : uint8_t { + F16 = 0, + F32 = 1, + S32 = 2, +}; + +CUTE_HOST_DEVICE char const* to_string(CFormat const& t) { + switch (t) { + case CFormat::F16: return "F16"; + case CFormat::F32: return "F32"; + case CFormat::S32: return "S32"; + } + return nullptr; +} + +enum class MaxShift : uint8_t { + NoShift = 0, + MaxShift8 = 1, + MaxShift16 = 2, + MaxShift32 = 3 +}; + +enum class BMatrixBufferId : uint8_t { + Zero = 0u, + One = 1u, + Two = 2u, + Three = 3u +}; + +enum class BMatrixBufferReuse : uint8_t { + Keep = 1u, + Reuse = 2u, + ReuseAndKeep = 3u +}; + +// using MaskAndShiftB = uint32_t[2]; +union MaskAndShiftB +{ + uint32_t uri[2]; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint8_t start_count_ [4]; // bit [ 0:32) : 8 bits each. Specifies the start count for mask generation. + uint32_t first_span_ : 4, // bit [32:36) : 1 bit each. 0 = start where B is used. 1 = start with where B is skipped(0 value is used). + : 3, // + nzm_ : 1, // bit [39:40) : 0 = Enable the mask. 1 = Disable the mask. + skip_span_ : 8, // bit [40:48) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where 0 value is used. + use_span_ : 8, // bit [48:55) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where B matrix data is used. + shift_ : 6, // bit [56:62) : Shift value for B matrix data. + : 2; + }; +}; + +template +CUTE_HOST_DEVICE constexpr auto +make_column_zero_mask(ShapeType conv_q, int32_t cta_coord_q, int32_t num_pixels_skip_left) { + + static_assert(cute::is_same_v || cute::is_integral::value); + + cute::array column_zero_masks{}; + + static_assert(FLT_S == 3, "Filter size not supported."); + constexpr int MAX_USE_SPAN_COUNT = 256; + constexpr int MAX_SKIP_SPAN_COUNT = 256; + + // conv_q_int used for non-divmod case (add/minus/..) + // conv_q used for divmod case (div/mod/...) + int32_t conv_q_int = int(conv_q); + auto [_, cta_q] = divmod(cta_coord_q * CTA_N, conv_q); + + int step_q = CTA_M == 128 ? CTA_N / 1 + : CTA_M == 64 ? CTA_N / 2 + : CTA_M == 32 ? CTA_N / 4 + : 0; + + for (int mask_iter = 0; mask_iter < int(CTA_N / step_q); ++mask_iter) { + + for (int s_iter = 0; s_iter < FLT_S; s_iter += 1) { + + int32_t skip_span{0}, use_span{0}, nzm{1}, first_span{0}, start_count{0}, shift{0}; + + shift = s_iter; + + // Examples for CZM setting + // CASE0: (skip_span_ < 0) + // | padding |<- conv_q ->| + // |skip_span_|<- use_span ->|skip_span_| + // -skip_span 0 ^cta_q conv_q-1 + // 0 ^index + // + // CASE1: (skip_span_ > 0) + // |<- conv_q ->| + // |skip_span_|<- use_span ->|skip_span_| + // 0 ^cta_q conv_q-1 + // 0 ^index + // + // line 0 an input vector from 0 to conv_q with the padding + // line 1 shows the different spans we need to skip or load + // lines 2-3 show the different coordinates of different boundaries. + // CTQ_q is the coordinate of the present cta. 
+ + int32_t skip_span_ = num_pixels_skip_left - shift; + int32_t index{0}; + if (skip_span_ > 0) { + auto [_, index_mod] = divmod(cta_q, conv_q); + index = index_mod; + } else if (skip_span_ < 0) { + auto [_, index_mod] = divmod((cta_q - skip_span_), conv_q); + index = index_mod; + } else { + nzm = 0; + } + skip_span = cute::max(cute::abs(skip_span_), 1); + use_span = cute::min(conv_q_int - static_cast(skip_span), MAX_USE_SPAN_COUNT); + if (use_span > 0) { + first_span = index >= skip_span ? 0 : 1; + if ((first_span == 0) && (index + CTA_N < conv_q_int + skip_span)) { + nzm = 0; + } else { + start_count = first_span == 0 ? (use_span - (conv_q_int - index)) : index; + } + } else { + skip_span = MAX_SKIP_SPAN_COUNT; + use_span = 1; + first_span = 1; + start_count = 0; + } + + column_zero_masks[s_iter].start_count_[mask_iter] = start_count; + column_zero_masks[s_iter].first_span_ |= first_span << mask_iter; + column_zero_masks[s_iter].nzm_ |= nzm; + column_zero_masks[s_iter].skip_span_ = skip_span - 1; + column_zero_masks[s_iter].use_span_ = use_span - 1; + column_zero_masks[s_iter].shift_ = shift; + + } + + cta_q += step_q; + } + + return column_zero_masks; +} + +template +CUTE_HOST_DEVICE constexpr auto to_UMMAFormat() { + if constexpr (is_same_v) { return F16F32Format::F16; } else + if constexpr (is_same_v) { return F16F32Format::BF16; } else + if constexpr (is_same_v) { return F16F32Format::TF32; } else + if constexpr (is_same_v) { return S8Format::UINT8; } else + if constexpr (is_same_v) { return S8Format::INT8; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + + if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else + + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else + if constexpr (is_same_v) { return MXF4Format::E2M1; } else + + { static_assert(sizeof(T) == 0, "Unknown type for UMMAFormat"); } +} + +template +CUTE_HOST_DEVICE constexpr CFormat to_CFormat() { + if constexpr (is_same_v) { return CFormat::F16; } else + if constexpr (is_same_v) { return CFormat::F32; } else + if constexpr (is_same_v) { return CFormat::S32; } else + { static_assert(sizeof(T) == 0, "Unknown type for CFormat"); } +} + +union InstrDescriptor +{ + uint32_t desc_; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2 + sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4 + saturate_ : 1, // bit [ 3, 4) : 0 = no saturate. 1 = saturate. 1 value valid only for S8 + c_format_ : 2, // bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32 + : 1, // + a_format_ : 3, // bit [ 7,10) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean + b_format_ : 3, // bit [10,13) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. 
F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean + a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats + : 1, // + m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) + : 1, // + max_shift_ : 2; // bit [30,32) : Maximum shift for WS instruction. Encoded as follows: 0 = no shift, 1 = maximum shift of 8, 2 = maximum shift of 16, 3 = maximum shift of 32. + }; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr explicit + operator uint32_t() const noexcept { return desc_; } +}; + +union InstrDescriptorBlockScaled +{ + uint32_t desc_; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2 + sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4 + : 1, // + b_sf_id_ : 2, // bit [ 4, 6) : Matrix B Scale Factor ID + : 1, // + a_format_ : 3, // bit [ 7, 9) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + b_format_ : 3, // bit [10,12) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats + scale_format_ : 1, // bit [23,24) : 0=E4M3, 1=E8M0 + m_dim_ : 5, // bit [24,29) : 4 LSBs not included. 
Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) + a_sf_id_ : 2, // bit [29,31) : Matrix A Scale Factor ID + : 1; // + }; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr + operator uint32_t() const noexcept { return desc_; } +}; + +template +CUTE_HOST_DEVICE constexpr +UMMA::InstrDescriptor +make_instr_desc() +{ + UMMA::InstrDescriptor desc_i = {}; + + desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.c_format_ = uint8_t(UMMA::to_CFormat()); + + desc_i.m_dim_ = (M >> 4); + desc_i.n_dim_ = (N >> 3); + + desc_i.a_major_ = uint8_t(a_major); + desc_i.b_major_ = uint8_t(b_major); + + desc_i.a_negate_ = uint8_t(a_neg); + desc_i.b_negate_ = uint8_t(b_neg); + desc_i.saturate_ = uint8_t(c_sat); + + desc_i.sparse_flag_ = is_sparse; // 1 = Sparse + desc_i.sparse_id2_ = 0; + + desc_i.max_shift_ = uint8_t(max_shift); + + return desc_i; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc(uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) { + UMMA::InstrDescriptor desc_i = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat, is_sparse, + max_shift>(); + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc(UMMA::InstrDescriptor desc_i, uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) +{ + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE constexpr +UMMA::InstrDescriptorBlockScaled +make_instr_desc_block_scaled() +{ + UMMA::InstrDescriptorBlockScaled desc_i = {}; + + desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat()); + + desc_i.scale_format_ = uint8_t(UMMA::to_ScaleFormat()); + desc_i.a_sf_id_ = 0; + desc_i.b_sf_id_ = 0; + + desc_i.m_dim_ = (M >> 4); + desc_i.n_dim_ = (N >> 3); + + desc_i.a_major_ = uint8_t(a_major); + desc_i.b_major_ = uint8_t(b_major); + + desc_i.a_negate_ = uint8_t(a_neg); + desc_i.b_negate_ = uint8_t(b_neg); + desc_i.sparse_flag_ = is_sparse; // 1 = Sparse + desc_i.sparse_id2_ = 0; + + // Below would bring some warnings. +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Wconversion" +#endif + return desc_i; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc_block_scaled(uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr, + uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u) +{ + UMMA::InstrDescriptorBlockScaled desc_i = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, + a_major, b_major, + a_neg, b_neg, + is_sparse>(); + + // The first 2-bits of TMEM address includes byte address. 
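+ // Concretely, the expressions below keep only bits [30,32) of each 32-bit TMEM address: + // (addr & 0xC0000000) >> 30 yields a 2-bit value in [0,3], matching the 2-bit a_sf_id_ and + // b_sf_id_ fields of InstrDescriptorBlockScaled.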
+ desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30; + desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30; + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc_block_scaled(UMMA::InstrDescriptorBlockScaled desc_i, + uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr, + uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u) +{ + // The first 2-bits of TMEM address includes byte address. + desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30; + desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30; + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +} // end namespace UMMA +} // namespace cute diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp new file mode 100644 index 0000000000..26ef131c37 --- /dev/null +++ b/include/cute/arch/mma_sm100_umma.hpp @@ -0,0 +1,1074 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +// + +// + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct SM100_MMA_TF32_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_TF32 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_TF32 N-mode size should be a multiple of 8 between 8 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static 
void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print("UMMA TMEM addr: 0x%08x\n", tmem_a); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + uint32_t mask[4] = {0, 0, 0, 0}; + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_TS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print("UMMA TMEM addr: 0x%08x\n", tmem_a); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + uint32_t mask[4] = {0, 0, 0, 0}; + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32 N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + +#if 0 + if (thread0()) { + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("Umma TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, 
p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS without SM100_MMA_TF32_2x1SM_SS"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16 N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + +#if 0 + if (thread0()) { + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("Umma TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) +#if 0 + if (thread0()) { + print(desc_i); + print("Umma TMEM-A addr: 0x%08x\n", tmem_a); + print(reinterpret_cast(desc_b)); + print("Umma TMEM-C addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_TS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N 
<= 256), "SM100_MMA_F16BF16 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) +#if 0 + if (thread0()) { + print(desc_i); + print("Umma TMEM-A addr: 0x%08x\n", tmem_a); + print(reinterpret_cast(desc_b)); + print("Umma TMEM-C addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_SS +{ + static_assert(is_same_v, "SM100_MMA_S8 result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_S8 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 8 between 8 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) +#if 0 + if (thread0()) { + 
UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print("UMMA TMEM addr: 0x%08x\n", tmem_a); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("Umma TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::i8 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print("Umma TMEM addr: 0x%08x\n", tmem_a); + print(reinterpret_cast(desc_b)); + print("Umma TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::i8 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), 
"r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +struct SM100_MMA_F8F6F4_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptor desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + } +#endif + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_MXF8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptorBlockScaled desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("UMMA TMEM addr: 0x%08x\n", tmem_c); + print("Umma SFA TMEM addr: 0x%08x\n", tsfa_addr); + print("Umma SFB TMEM addr: 0x%08x\n", tsfb_addr); + print("===================================\n"); + } +#endif + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 8 between 8 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& tmem_a,
+      uint64_t const& desc_b,
+      uint32_t const& tmem_c,
+      uint32_t const& scaleC,
+      uint64_t const& idescE)
+  {
+#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED)
+#if 0
+    if (thread0()) {
+      UMMA::InstrDescriptor desc_i;
+      desc_i.desc_ = uint32_t(idescE >> 32);
+      print(desc_i);
+      print("UMMA TMEM addr: 0x%08x\n", tmem_a);
+      print(reinterpret_cast(desc_b));
+      print("UMMA TMEM addr: 0x%08x\n", tmem_c);
+    }
+#endif
+    if (cute::elect_one_sync()) {
+      uint32_t mask[4] = {0, 0, 0, 0};
+      asm volatile(
+        "{\n\t"
+        ".reg .pred p;\n\t"
+        "setp.ne.b32 p, %4, 0;\n\t"
+        "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t"
+        "}\n"
+        :
+        : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
+          "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
+    }
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_MMA_SM100A_ENABLED");
+#endif
+  }
+};
+
+template
+struct SM100_MMA_F8F6F4_2x1SM_TS
+{
+  static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4 M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
+  static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 16 between 16 and 256.");
+  static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4 A from TMEM can't be transposed");
+
+  using DRegisters = void;
+  using ARegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& tmem_a,
+      uint64_t const& desc_b,
+      uint32_t const& tmem_c,
+      uint32_t const& scaleC,
+      uint64_t const& idescE)
+  {
+#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED)
+#if 0
+    if (thread0()) {
+      UMMA::InstrDescriptor desc_i;
+      desc_i.desc_ = uint32_t(idescE >> 32);
+      print("UMMA TMEM addr: 0x%08x\n", tmem_a);
+      print(reinterpret_cast(desc_b));
+      print("UMMA TMEM addr: 0x%08x\n", tmem_c);
+    }
+#endif
+    if (cute::elect_one_sync()) {
+      uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+      asm volatile(
+        "{\n\t"
+        ".reg .pred p;\n\t"
+        "setp.ne.b32 p, %4, 0;\n\t"
+        "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
+        "}\n"
+        :
+        : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
+          "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
+          "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
+    }
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_TS without CUTE_ARCH_MMA_SM100A_ENABLED");
+#endif
+  }
+};
+
+struct SM100_MMA_F8F6F4_2x1SM_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t const& tmem_c,
+      uint32_t const& scaleC,
+      uint64_t const& idescE)
+  {
+#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED)
+#if 0
+    if (thread0()) {
+      UMMA::InstrDescriptor desc_i;
+      desc_i.desc_ = uint32_t(idescE >> 32);
+      print(desc_i);
+      print(reinterpret_cast(desc_a));
+      print(reinterpret_cast(desc_b));
+      print("Umma TMEM addr: 0x%08x\n", tmem_c);
+    }
+#endif
+    if (cute::elect_one_sync()) {
+      uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+      asm volatile(
+        "{\n\t"
+        ".reg .pred p;\n\t"
+        "setp.ne.b32 p, %4, 0;\n\t"
+        "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
+        "}\n"
+        :
+        : "r"(tmem_c),
"l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_MXF8F6F4 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4 N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) +#if 0 + if (thread0()) { + UMMA::InstrDescriptorBlockScaled desc_i; + desc_i.desc_ = uint32_t(idescE >> 32); + print(desc_i); + print(reinterpret_cast(desc_a)); + print(reinterpret_cast(desc_b)); + print("Umma TMEM addr: 0x%08x\n", tmem_c); + print("Umma SFA TMEM addr: 0x%08x\n", tsfa_addr); + print("Umma SFB TMEM addr: 0x%08x\n", tsfb_addr); + print("===================================\n"); + } +#endif + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + + +template +struct SM100_MMA_MXF4_SS +{ + static_assert(M == 128, "SM100_MMA_MXF4 M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 16) || (VS == 32), "Vector size can only be 16 or 32."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { + if constexpr (VS == 16) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_SS (VS = 16) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, 
%2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_SS (VS = 32) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } + +}; + + + +template +struct SM100_MMA_MXF4_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_MXF4 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 16) || (VS == 32), "Vector size can only be 16 or 32."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { + if constexpr (VS == 16) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_2x1SM_SS (VS = 16) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_2x1SM_SS (VS = 32) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + + +} // end namespace cute diff --git a/include/cute/arch/simd_sm100.hpp b/include/cute/arch/simd_sm100.hpp new file mode 100644 index 0000000000..1c07a31e6d --- /dev/null +++ b/include/cute/arch/simd_sm100.hpp @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#include +#include +#include + +namespace cute { + +CUTE_HOST_DEVICE +void +add(float2 & c, + float2 const& a, + float2 const& b) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("add.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +#else + add(c.x, a.x, b.x); + add(c.y, a.y, b.y); +#endif +} + +CUTE_HOST_DEVICE +void +mul(float2 & c, + float2 const& a, + float2 const& b) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("mul.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +#else + mul(c.x, a.x, b.x); + mul(c.y, a.y, b.y); +#endif +} + +CUTE_HOST_DEVICE +void +fma(float2 & d, + float2 const& a, + float2 const& b, + float2 const& c) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +#else + fma(d.x, a.x, b.x, c.x); + fma(d.y, a.y, b.y, c.y); +#endif +} + +} // namespace cute diff --git a/include/cute/arch/tmem_allocator_sm100.hpp b/include/cute/arch/tmem_allocator_sm100.hpp new file mode 100644 index 0000000000..2d2cac9d2f --- /dev/null +++ b/include/cute/arch/tmem_allocator_sm100.hpp @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// +#pragma once + +#include +#include +#include + +#include + +namespace cute::TMEM { + +// All operations of this class require that only a single warp uniformly participates +class Allocator1Sm { +public: + static constexpr int ColumnsPerAllocationSlice = 32; + static constexpr int Sm100TmemCapacityColumns = 512; + + __device__ Allocator1Sm() { } + + /** + * Performs a non-blocking allocation of TMEM. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre For repeated allocations, the same warp must be used to issue all allocations. + **/ + __device__ void + allocate(int num_columns, uint32_t* dst_ptr) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ + void + free(uint32_t tmem_ptr, int num_columns) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile( + "{\n\t" + "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(tmem_ptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ void + release_allocation_lock() { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;" ::); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +class Allocator2Sm { +public: + static constexpr int ColumnsPerAllocationSlice = 32; + static constexpr int Sm100TmemCapacityColumns = 512; + + __device__ Allocator2Sm() { } + + /** + * Performs a non-blocking allocation of TMEM. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to. + * Both CTAs _must_ provide the exact same dst_ptr for correctness. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. 
+ * @pre For repeated allocations, the same warp must be used to issue all allocations. + * @pre The 2 warps from participating CTAs have the same logical warp ID. + **/ + __device__ void + allocate(int num_columns, uint32_t* dst_ptr) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::2.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + /** + * Frees the TMEM corresponding to the pointer and slice count provided. + * Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices. + * @param tmem_ptr Base address of the TMEM address space being freed. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre The 2 warps from participating CTAs have the same logical warp ID. + * @returns true + **/ + __device__ + void + free(uint32_t tmem_ptr, int num_columns) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile( + "{\n\t" + "tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(tmem_ptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ + void + release_allocation_lock() { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::2.sync.aligned;" ::); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } +}; + +} // namespace cute::TMEM diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 612ef0b6e3..96ffbec230 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -751,14 +751,33 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and #include #include #include +#include + // Config #if (__CUDACC_VER_MAJOR__ >= 12) # define CUTE_COPY_ATOM_TMA_SM90_ENABLED +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED +#endif + + +#if (!defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#endif + +#if (!defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED)) +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED #endif + #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) #include #endif + +#if defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED) +#include +#endif + + //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp new file mode 100644 index 0000000000..bc0d956bb7 --- /dev/null +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -0,0 +1,3797 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#include + +#include +#include +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_8, _2, _2, _2>>, + Stride,Stride<_1,_128,_64,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_8, _2, _2, _4>>, + Stride,Stride<_1,_128,_64,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // 
Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2>>, + Stride,Stride<_1,_128,_64>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_128>, + Stride, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2, _2>>, + Stride,Stride<_1,_128,_64,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_128>, + Stride, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2, _4>>, + Stride,Stride<_1,_128,_64,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_128, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +namespace TMEM { + using MAX_CAPACITY_BITS = Int<128*512*32>; // 128 DP x 512 COL x uint32_t-addressing + + template // TMEM DP stride in type-T addressing + using DP = cute::constant::OffsetShift)>; + + using DP_b = cute::constant; // TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) +} + +// TMEM_LOAD copy_unpack +template +struct TMEM_LOAD_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_tmem::value, "Expected TMEM src."); + static_assert(is_rmem::value, "Expected RMEM dst."); + + using SrcType = typename TS::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected src to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(src.data()); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); + + // thread idx <=> DP lane assert. 
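+  // Editor's note (descriptive only, derived from the assert below): bits [31:16] of a TMEM
+  // address select the datapath (DP) lane, and lanes are grouped into four sub-partitions of
+  // 32 lanes each. A warp may only access lanes inside its own sub-partition, so warp index
+  // (threadIdx.x / 32) mod 4 must equal ((tmem_addr >> 16) / 32) mod 4, which is exactly the
+  // condition checked next.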
+ // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}); + } +}; + +// TMEM_STORE copy_unpack +template +struct TMEM_STORE_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected RMEM src."); + static_assert(is_tmem::value, "Expected TMEM dst."); + + using RegTypeSrc = typename remove_extent::type; + Tensor rS = recast(src); + + constexpr int RegNumSrc = extent::value; + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + + using DstType = typename TD::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected dst to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(dst.data()); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + rS, make_seq{}, + &tmem_addr, seq<0>{}); + } +}; + +template +struct Copy_Atom; + +/** Generate a TiledCopy from a CopyAtom and a TMEM tensor + * Example: + * Tensor gmem_tensor = ... // (M,N,...) + * Tensor tmem_tensor = ... // (M,N,...) + * auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD_Operation, tmem_tensor); + * auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + * + * Tensor tDtC = thr_tmem_load.partition_S(tmem_tensor); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * Tensor tDgC = thr_tmem_load.partition_D(gmem_tensor); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * Tensor tDrC = make_tensor(shape(tDgD)); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * + * copy(tiled_tmem_load, tDtC, tDrC); // tmem -> rmem + * copy(tDrC, tDgC); // rmem -> gmem + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_copy(Copy_Atom const& atom, + Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + using Traits = typename Copy_Atom::Traits; + static_assert(sizeof_bits_v == sizeof_bits_v, + "Expected a CopyAtom with the same type-width as the Tensor."); + + // atom thr idx -> tmem addr 4warps where each warp points to the same position within it's own subpartition + auto atom_t_layout = Layout, Stride<_0, decltype(Int<32>{} * TMEM::DP{})>>{}; + // atom val idx -> tmem addr Cast the CopyOp's value ids to the proper data width + auto atom_v_layout = coalesce(upcast::value>(typename Traits::ValID{})); + + return make_cotiled_copy(atom, make_layout(atom_t_layout, atom_v_layout), tmem.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_copy(CopyOp const&, + Tensor const& tmem) +{ + return make_tmem_copy(Copy_Atom{}, tmem); +} + +/** Generate a TV_Tiler from a TMEM tensor + * Example: + * Tensor gmem_tensor = ... // (M,N,...) + * Tensor tmem_tensor = ... // (M,N,...) + * auto tmem_tiler = make_tmem_warp_partitioner(tmem_tensor); + * auto warp_tiler = tmem_tiler.get_slice(warp_idx); + * + * Tensor tWtC = warp_tiler.partition(tmem_tensor); // (WARP_M,WARP_N,...) 
+ * Tensor tWgC = warp_tiler.partition(gmem_tensor); // (WARP_M,WARP_N,...) + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_warp_partitioner(Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + + // warp idx -> tmem addr This is the T in the Layout_TV + auto atom_t_layout = Layout<_4, decltype(Int<32>{} * TMEM::DP{})>{}; + + // tmem coord -> tmem addr + auto tmem_layout = tmem.layout(); + // tmem addr -> tmem coord Append 1:0 so off-the-ends get the stride-0 + auto inv_tmem_layout = make_layout(left_inverse(tmem_layout), Layout<_1,_0>{}); + + // wid -> tmem_coord + auto layout_t_tmem = composition(inv_tmem_layout, atom_t_layout); + +#if 0 + if (thread0()) { + print("input : "); print(tmem.data()); print(" o "); print(tmem_layout); print("\n"); + print("atom_t_layout : "); print(atom_t_layout); print("\n"); + print("layout_tv_tmem : "); print(layout_tv_tmem); print("\n"); + } +#endif + + // + // Tiler -- Find the active elements in the TMEM tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled TMEM + auto flat_tmem_shape = product_each(shape(tmem_layout)); + auto flat_tmem_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_tmem_shape, replace(flat_tmem_zeros, Int<1>{})), layout_t_tmem)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> tmem_coord + auto tile2tmem = composition(make_layout(flat_tmem_shape), tiler); + + // wid -> tile_coord + auto layout_tv = composition(left_inverse(tile2tmem), layout_t_tmem); + +#if 0 + if (thread0()) { + print("tiler : "); print(tiler); print("\n"); + print("tile2tmem : "); print(tile2tmem); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + } +#endif + + return make_tiler_impl(layout_tv, tiler); +} + +} // end namespace cute +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMEM { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Given a 1x tmem copy op, returns the widest repeated variant that divides the specified bits in the N-mode +template +CUTE_HOST_DEVICE constexpr +auto +op_repeater() +{ + if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + 
return SM100_TMEM_LOAD_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + 
return SM100_TMEM_LOAD_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return 
SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if 
constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + } + else { + static_assert(dependent_false, "Must pass 1x tmem copy operator"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Select TMEM store corresponding to the provided TMEM load +template +CUTE_HOST_DEVICE constexpr auto +tmem_load_to_store(CopyOp) { + if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + else if constexpr (is_same_v) { + return 
SM100_TMEM_STORE_16dp256b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; 
+ } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else { + static_assert(dependent_false, "No TMEM_STORE matching for provided TMEM_LOAD"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace TMEM + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_LOAD Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + // Logical bit id to bit idx (address) + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : 
TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; 
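
////////////////////////////////////////////////////////////////////////////////////////////////////

// Minimal usage sketch for the TMEM::op_repeater and TMEM::tmem_load_to_store helpers defined
// above. Illustrative only: the accumulator size is an assumed example, and the template-parameter
// order for op_repeater (the 1x op first, then the N-mode bit count) is an assumption.
//
//   // A TMEM-resident accumulator with 256 columns of 32-bit data: 256 * 32 = 8192 bits in N.
//   constexpr int bits_n = 256 * 32;
//   // Widest repeated 32dp32b load whose footprint divides 8192 bits (per the dispatch above):
//   auto load_op  = TMEM::op_repeater<SM100_TMEM_LOAD_32dp32b1x, bits_n>();   // SM100_TMEM_LOAD_32dp32b128x
//   // TMEM store op with the same data-path / bit-width shape as a given load op:
//   auto store_op = TMEM::tmem_load_to_store(SM100_TMEM_LOAD_32dp32b1x{});    // SM100_TMEM_STORE_32dp32b1x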
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< 
_0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack 
+{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct 
Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : 
TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using 
DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct Copy_Traits + : TMEM_LOAD_Unpack +{ + using 
ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_STORE Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Copy_Traits + : TMEM_STORE_Unpack +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
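
////////////////////////////////////////////////////////////////////////////////////////////////////

// Note: each TMEM_STORE trait in this section aliases the ThrID/ValID of its TMEM_LOAD counterpart
// and swaps that counterpart's SrcLayout and DstLayout. A minimal sketch of the consequence,
// assuming a 32-bit TMEM-resident accumulator tensor named tAcc (the name is a placeholder):
//
//   auto tiled_t2r = make_tmem_copy(SM100_TMEM_LOAD_32dp32b32x{}, tAcc);                            // TMEM -> RF
//   auto tiled_r2t = make_tmem_copy(TMEM::tmem_load_to_store(SM100_TMEM_LOAD_32dp32b32x{}), tAcc);  // RF -> TMEM
//   // Both tiled copies are built from the same ThrID/ValID, so they partition tAcc identically;
//   // only the direction of the transfer differs.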
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Each of the remaining TMEM_STORE Copy_Traits specializations follows the same pattern: it
+// derives from TMEM_STORE_Unpack for its op and reuses the ThrID, ValID, and RefLayout of the
+// matching traits, with the source and destination layouts swapped. Schematically (StoreOp and
+// MatchingOp stand in for the concrete op types of each specialization):
+//
+//   template <>
+//   struct Copy_Traits<StoreOp>
+//     : TMEM_STORE_Unpack<StoreOp>
+//   {
+//     using ThrID     = typename Copy_Traits<MatchingOp>::ThrID;
+//     using ValID     = typename Copy_Traits<MatchingOp>::ValID;
+//     using SrcLayout = typename Copy_Traits<MatchingOp>::DstLayout;
+//     using DstLayout = typename Copy_Traits<MatchingOp>::SrcLayout;
+//     using RefLayout = typename Copy_Traits<MatchingOp>::RefLayout;
+//   };
+//
+// One such specialization is provided for each TMEM_STORE variant.
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// UTCCP Copy Traits
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// In the following UTCCP traits, the ValID is representing:
+// logical_bit_idx -> 
tmem_addr_offset. +// And the logical_bit_idx is numbered in the order of: +// [core_matrix_strided, core_matrix_leading, broadcast, repeat]. +// The first two modes provide convenience for smem_desc construtction. +// The last two modes provide boradcast transformation for 4x32DP and 2x64DP. +// With above, the strides of first two modes are neccessary to be TMEM::DP_b and 1. +// And the stride of the third mode in the SrcLayout must be zero. +template <> +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + using ValID = Layout, + Stride>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_128dp256bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // logical bit_idx -> tmem_addr + using ValID = typename Copy_Traits::ValID; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_128dp256bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + using ValID = Layout, + Stride>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_128dp128bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // logical bit_idx -> tmem_addr + using ValID = typename Copy_Traits::ValID; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_128dp128bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + /* + 4DP is really hard to model if we consider 
this instruction as a "copy" instruction. + But, if we take it as "TMEM refresh" instruction, then everything goes out naturally. + 4DP utccp is designed to refresh the last 4 lanes of each tmem subpartition. + So, in the kernel implementation, we usually only don't need to iterate on MMA_M dimension, + but only need to iterate on MMA_K dimension. + And in each refresh, logically we are refreshing MMA's 128 rows M + 256bit K. + So the "atom_v" should be (refresh_m, refresh_k) instead of (copy_m, copy_k). + And the Src/DstLayout below is: copy_bits -> logical_refresh_bits. + */ + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + using ValID = Layout, + Stride>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_4dp256bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + + using ThrID = Layout<_2>; + // logical bit_idx -> tmem_addr + using ValID = typename Copy_Traits::ValID; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_4dp256bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx32 = Int<_DP{}*32>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, broadcast] + using ValID = Layout, + Stride<_DP,_1, _DPx32>>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _0>>>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_4x32dp128bit_1cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + + using ThrID = Layout<_2>; + // logical bit_idx -> tmem_addr + using ValID = typename Copy_Traits::ValID; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _0>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor 
& dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_4x32dp128bit_2cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx64 = Int<_DP{}*64>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, broadcast] + using ValID = Layout, + Stride<_DP,_1, _DPx64>>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _64, _0>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_2x64dp128bitlw0213_1cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + + using ThrID = Layout<_2>; + // logical bit_idx -> tmem_addr + using ValID = typename Copy_Traits::ValID; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _64, _0>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx32 = Int<_DP{}*32>; + using _DPx64 = Int<_DP{}*64>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, repeat, broadcast] + using ValID = Layout, + Stride<_DP,_1 ,_DPx64,_DPx32>>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32,_4096,_0>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_2x64dp128bitlw0123_1cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template <> +struct Copy_Traits +{ + + using ThrID = Layout<_2>; + // logical bit_idx -> tmem_addr + using ValID = typename Copy_Traits::ValID; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _4096,_0>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_rmem::value, "Expected smem_desc src for SM100_UTCCP"); + 
static_assert(is_tmem::value, "Expected tmem dst for SM100_UTCCP"); + SM100_UTCCP_2x64dp128bitlw0123_2cta::copy(src[0], raw_pointer_cast(dst.data())); + } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_utccp_copy(CopyOp const&, + Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + using Traits = Copy_Traits; + using Atom = Copy_Atom; + + // atom thr idx -> tmem addr This is the T in the Layout_TV + auto atom_t_layout = make_layout(size(typename Traits::ThrID{}), Int<0>{}); + // atom val idx -> tmem addr Cast the CopyOp's value ids to the proper data width + auto atom_v_layout = coalesce(upcast::value>(typename Traits::ValID{})); + + return make_cotiled_copy(Atom{}, make_layout(atom_t_layout, atom_v_layout), tmem.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_sm100_im2col.hpp b/include/cute/atom/copy_traits_sm100_im2col.hpp new file mode 100644 index 0000000000..cd3bf98bbe --- /dev/null +++ b/include/cute/atom/copy_traits_sm100_im2col.hpp @@ -0,0 +1,488 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! 
\file + \brief im2col make_tma_copy + +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/atom/copy_traits_sm90_im2col.hpp" +#include "cute/tensor.hpp" + +namespace cute { + +struct SM100_TMA_2SM_LOAD_IM2COL_OP : SM100_TMA_2SM_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM100 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for consistency with the actual multicast Copy_Traits + /// specialization) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// TMA load, with TMA descriptor and barrier. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP : SM100_TMA_2SM_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM100 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. 
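+///
+/// Illustrative usage sketch: the barrier, mask, tensor, and TiledCopy names
+/// below are placeholders assumed for this example, not symbols defined in
+/// this file.
+///
+///   // Given a TiledCopy built around this atom, inside the kernel:
+///   uint64_t& tma_mbar   = shared_storage.tma_load_mbar;  // hypothetical smem mbarrier
+///   uint16_t  mcast_mask = 0x3;                           // hypothetical multicast mask
+///   cute::copy(tma_load_a.with(tma_mbar, mcast_mask), tAgA, tAsA);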
+template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const + { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM100_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM100 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit. + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +//////////////////////////////////// +// Make TMA +/////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with. + * This is often the cluster_tile_shape that is used to tile the GMEM: + * local_tile(gtensor, cluster_tile_shape, cluster_coord) + * -> Cluster-local tile of GMEM + * @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning. + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. 
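+ *
+ * A call sketch follows; the tensor, layout, and stride names are placeholders
+ * assumed for illustration, not symbols defined in this header:
+ *
+ *   auto tma_a = make_im2col_tma_copy_A_sm100(
+ *       SM100_TMA_2SM_LOAD_IM2COL{},          // copy_op
+ *       gA, sA_layout,                        // GMEM tensor (M,K,...) and SMEM layout
+ *       cluster_tile_shape, tiled_mma,        // cluster tile and TiledMMA
+ *       lower_corner_whd, upper_corner_whd,   // im2col corner offsets
+ *       lower_padding_whd, upper_padding_whd, // im2col padding
+ *       stride_whd, lower_srt, stride_srt);   // traversal and dilation strides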
+ * + * Examples: + */ +template +CUTE_HOST +auto +make_im2col_tma_copy_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M,K,...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K) + Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K) + TiledMMA const& mma, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only MK modes from MNK + auto cluster_tile_shape = append(make_shape(get<0>(cluster_tile), get<2>(cluster_tile)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat(_)); + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)); + + return detail::make_tma_copy_im2col(copy_op, gtensor, slayout, + cta_t_map, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, + lower_srt, stride_srt, aux_params); +} + +template +CUTE_HOST +auto +make_im2col_tma_copy_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N,K,...) 
+ SLayout const& slayout, // (MMA, MMA_N, MMA_K) + Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K) + TiledMMA const& mma, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only NK modes from MNK + auto cluster_tile_shape = append(make_shape(get<1>(cluster_tile), get<2>(cluster_tile)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat(_)); + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)); + + return detail::make_tma_copy_im2col(copy_op, gtensor, slayout, + cta_t_map, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, + lower_srt, stride_srt, aux_params); +} + +///////////////////////////////////// +// Experimental Make Im2col TMA Atom +///////////////////////////////////// + +template +CUTE_HOST +auto +make_im2col_tma_atom_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K) + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only MK modes from MNK + auto cluster_tile_shape = append(make_shape(get<0>(mma_tiler), get<2>(mma_tiler)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat(_)); + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, + stride_whd, lower_srt, stride_srt, aux_params); +} + +template +CUTE_HOST +auto +make_im2col_tma_atom_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) 
+ MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K) + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only NK modes from MNK + auto cluster_tile_shape = append(make_shape(get<1>(mma_tiler), get<2>(mma_tiler)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat(_)); + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, + stride_whd, lower_srt, stride_srt, aux_params); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm100_tma.hpp b/include/cute/atom/copy_traits_sm100_tma.hpp new file mode 100644 index 0000000000..851db2891d --- /dev/null +++ b/include/cute/atom/copy_traits_sm100_tma.hpp @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2021 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include +#include + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////// TMA_LOAD //////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_OP : SM100_TMA_2SM_LOAD {}; + +// The non-executable SM100_TMA_2SM_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, static_cast(cache_hint)}}; + } + + // Construct an executable SM100_TMA_2SM_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, static_cast(cache_hint)}}; + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM100_TMA_2SM_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_MULTICAST_OP : SM100_TMA_2SM_LOAD_MULTICAST {}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using 
SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_MULTICAST_OP arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + } + + // Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD_MULTICAST_OP before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_MULTICAST_OP arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; +}; + +//////////////////////////////////// +// Make TMA +/////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with. + * This is often the cluster_tile_shape that is used to tile the GMEM: + * local_tile(gtensor, cluster_tile_shape, cluster_coord) + * -> Cluster-local tile of GMEM + * @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning. + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. 
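+ *
+ * A call sketch follows; the tensor and layout names are placeholders assumed
+ * for illustration, not symbols defined in this header:
+ *
+ *   auto tma_a = make_tma_copy_A_sm100(
+ *       SM100_TMA_2SM_LOAD{},    // copy_op
+ *       gA,                      // GMEM tensor (M, K, ...)
+ *       sA_layout,               // SMEM layout (MMA, MMA_M, MMA_K, ...)
+ *       cluster_tiler,           // (TILER_M, TILER_N, TILER_K, ...)
+ *       tiled_mma);              // TiledMMA defining the CTA partitioning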
+ * + * Examples: + */ +template +CUTE_HOST +auto +make_tma_copy_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + Cluster_Tiler const& cluster_tiler, // (TILER_M, TILER_N, TILER_K, ...) + TiledMMA const& mma) +{ + // Keep only MK modes from MNK + auto cluster_tiler_mk = remove<1>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mk); // (TILE_M, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_K, ...) + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides))); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_copy_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma) +{ + // Keep only NK modes from MNK + auto cluster_tiler_nk = remove<0>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_nk); // (TILE_N, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat(_)); // (MMA, MMA_N, MMA_K, ...) + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides))); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_copy_C_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, N, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_N, ...) + Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma) +{ + // Keep only MN modes from MNK + auto cluster_tiler_mn = remove<2>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mn); // (TILE_M, TILE_N, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_C(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_N, ...) 
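+ // Unlike the A/B paths above, no multicast is used for C: the static_assert
+ // below only admits the non-multicast TMA ops, so cta_t_map is trivial.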
+ + static_assert(is_same_v || + is_same_v || + is_same_v, + "Unsupported TMA Op, expected a non-multicast TMA"); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +//////////////////////////////////// +// Experimental Make TMA Atom +/////////////////////////////////// + +template +CUTE_HOST +auto +make_tma_atom_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K) +{ + // Keep only MK modes from MNK + auto mma_tiler_mk = remove<1>(mma_tiler); + + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_mk); // (TILE_M, TILE_K, ...) + + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_K, ...) + +#if 0 + print("(tma_a) slayout: "); print(slayout); print("\n"); + print("(tma_a) mma_tiler_mk: "); print(mma_tiler_mk); print("\n"); + print("(tma_a) g_tile: "); print(g_tile); print("\n"); + print("(tma_a) mma_tiler: "); print(mma_tiler); print("\n"); + print("(tma_a) cta_v_tile: "); print(cta_v_tile); print("\n"); +#endif + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, gtensor, slayout, num_multicast, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_atom_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K) +{ + // Keep only NK modes from MNK + auto mma_tiler_nk = remove<0>(mma_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_nk); // (TILE_N, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat(_)); // (MMA, MMA_N, MMA_K, ...)
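+ // Optional host-side debug dump of the layouts involved; switch to #if 1 to enable.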
+ +#if 0 + print("(tma_b) slayout: "); print(slayout); print("\n"); + print("(tma_b) mma_tiler_nk: "); print(mma_tiler_nk); print("\n"); + print("(tma_b) g_tile: "); print(g_tile); print("\n"); + print("(tma_b) mma_tiler: "); print(mma_tiler); print("\n"); + print("(tma_b) cta_v_tile: "); print(cta_v_tile); print("\n"); +#endif + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, gtensor, slayout, num_multicast, cta_v_tile); +} + +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp index eaf3c020fb..9a447896cc 100644 --- a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -56,6 +56,13 @@ get_tma_swizzle_bits(Swizzle) case 0: return TMA::SmemSwizzleBits::DISABLE; } } else + + if constexpr (M == 5 || M == 6) { + static_assert(B == 2, "Expected B = 2 when M == 5 or 6. Unsupported layout swizzle."); + // S-condition as well? + return TMA::SmemSwizzleBits::B128; + } else + { static_assert(M < 0, "Unsupported layout swizzle."); } @@ -78,9 +85,25 @@ get_tma_swizzle_base(Swizzle) static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B; } + + else if constexpr (M == 5) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_32B; + } else if constexpr (M == 6) { + static_assert(B == 2, "Expected B = 2 when M == 5. 
Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_64B; + } + #if 1 + else { + static_assert(4 <= M && M <= 6, "Expected 128b=16B=(2^4)B to 512b=64B=(2^6)B base swizzle."); + } + #else + else { static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); } + #endif } template diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 957f070771..35a5c8a1c9 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -154,6 +154,10 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeA is a view (of the current tensor), forward the whole static_assert(is_same::value_type>::value + + || (sizeof_bits_v::value_type> == 8 && + (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + , "Expecting ValTypeA type"); return make_tensor(static_cast(atensor)); } else { @@ -176,6 +180,10 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeB is a view (of the current tensor), forward the whole static_assert(is_same::value_type>::value + + || (sizeof_bits_v::value_type> == 8 && + (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + , "Expecting ValTypeB type"); return make_tensor(static_cast(btensor)); } else { @@ -1109,4 +1117,5 @@ print_svg(TiledMMA const &mma) { #include #include #include +#include //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp new file mode 100644 index 0000000000..71a9dd2a3b --- /dev/null +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -0,0 +1,2425 @@ +/*************************************************************************************************** + * Copyright (c) 2022 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#pragma once + +#include +#include +#include +#include +#include +#include +#include // cute::TMEM:: +#include +#include // cute::GMMA:: +#include // cute::GMMA:: +#include + +// Check that aggregate initialization in .with() initializes all fields +#if defined(__GNUG__) +#pragma GCC diagnostic warning "-Wmissing-field-initializers" +#pragma GCC diagnostic error "-Wmissing-field-initializers" +#endif + +namespace cute { + +namespace UMMA { + +////////////////////////////////////////////////// +// Common layouts for UMMA Shared Memory // +////////////////////////////////////////////////// + +// TODO: Extend for remaining sm100 new layouts +using cute::GMMA::Layout_MN_INTER_Atom; +using cute::GMMA::Layout_MN_SW32_Atom; +using cute::GMMA::Layout_MN_SW64_Atom; +using cute::GMMA::Layout_MN_SW128_Atom; +using cute::GMMA::Layout_K_INTER_Atom; +using cute::GMMA::Layout_K_SW32_Atom; +using cute::GMMA::Layout_K_SW64_Atom; +using cute::GMMA::Layout_K_SW128_Atom; + +using Layout_MN_SW128_32B_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _1024>>>; + +template +using Layout_MN_SW128_32B_Atom = decltype(upcast::value>(Layout_MN_SW128_32B_Atom_Bits{})); + +// Tile a MN-logical layout atom to an MMA Tile Shape ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_mma_shape(LayoutAtom const& atom, MMATileShape const& mma_tile_shape, ModeOrder const& order = {}) +{ + constexpr int R = decltype(rank(mma_tile_shape))::value; + auto mn_shape = cute::tuple_cat(zip(shape<0>(mma_tile_shape), take<1,3>(mma_tile_shape)), take<3,R>(mma_tile_shape)); + auto mn_tiled = tile_to_shape(atom, mn_shape, order); // (BLK_M,BLK_N,...) + return tiled_divide(mn_tiled, product_each(shape<0>(mma_tile_shape))); // ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) +} + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + + if constexpr (M == 4) { + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + case 0: return LayoutType::SWIZZLE_NONE; + case 1: return LayoutType::SWIZZLE_32B; + case 2: return LayoutType::SWIZZLE_64B; + case 3: return LayoutType::SWIZZLE_128B; + } + } else + if constexpr (M == 5) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle."); + return LayoutType::SWIZZLE_128B_BASE32B; + } else { + static_assert(M==5, "Only 16B and 32B Atoms are supported for UMMA. 
Unsupported layout swizzle."); + return LayoutType::SWIZZLE_NONE; // ERROR + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for UMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_umma_desc // +* /////////////////////////////// +* Each UmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* LayoutType::128B_BASE32B : Swizzle<2,5,2> o smem_ptr o ((T,8,m),(4,k)):((1,T,LBO),(?T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to UMMA shape +* k : integer in [1,32] corresponding to UMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See UMMA::Layout_MN_XXX_Atom for building canonical UmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_umma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_umma_desc // +* ////////////////////////////// +* Each UmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See UMMA::Layout_K_XXX_Atom for building canonical UmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_umma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +SmemDescriptor +make_umma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "UMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "UMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + SmemDescriptor desc; + desc.version_ = 1; // Set the version for blackwell + desc.lbo_mode_ = 0; // set to legacy mode by default + + // Layout type + constexpr UMMA::LayoutType LAYOUT_TYPE = UMMA::layout_type(u128_tensor); + desc.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.start_address_ = static_cast(start_address >> 4); + + constexpr uint8_t base_offset = 0; + desc.base_offset_ = base_offset; + + // LayoutType meta + constexpr int SwizzleAtomMNSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? 1 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_32B ? 2 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_64B ? 4 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B ? 8 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 
8 : -1; + + if constexpr (MajorMode == UMMA::Major::MN) + { + /* In units of uint128_t, each UmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + * LayoutType::B128_BASE32B : Swizzle<2,5,2> o smem_ptr o ((8,n),(4,k)):((1,LBO),(4,SBO)) + */ + + constexpr int SwizzleAtomKSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 4 : 8; + + // Construct the canonical UMMA T Layout with shape + // ((SwizzleAtomMNSize,n),(SwizzleAtomKSize,2)) + Layout canonical_layout = + logical_divide(layout(u128_tensor), + make_tile(Layout, _1>{}, + Layout, _1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_MN Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_MN Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical UMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = SwizzleAtomMNSize; + static_assert(stride_10 == expected_stride_10, "Not a canonical UMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.stride_byte_offset_ = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride_01 : stride_11; + desc.leading_byte_offset_ = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? 
stride_11 : stride_01; + } else + if constexpr (MajorMode == UMMA::Major::K) + { + /* In units of uint128_t, each UmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + * LayoutType::B128_BASE32B : Not applicable for Major-K + */ + + static_assert(LAYOUT_TYPE != UMMA::LayoutType::SWIZZLE_128B_BASE32B, "SWIZZLE_128B_BASE32B is invalid for Major-K"); + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical UMMA_K Layout: Expected MN-size multiple of 8."); + + // Construct the canonical UMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_K Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_K Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = SwizzleAtomMNSize; + static_assert(stride_00 == expected_stride_00, "Not a canonical UMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical UMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.stride_byte_offset_ = stride_01; + desc.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != UMMA::Major::MN && MajorMode != UMMA::Major::K, "Unrecognized MajorMode!"); + } + +#if 0 + // DEBUG and SANITY + assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation + assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later + if (thread0()) { + print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); + print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); + //print(" desc canonical layout: "); print(canonical_layout); print("\n"); + print(desc); + } +#endif + + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level UMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = SmemDescriptor; + using element_type = SmemDescriptor; + using value_type = SmemDescriptor; + + SmemDescriptor desc_; + + // Dereference returns the UmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new UmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator 
operator+(Index const& offset) const + { + // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc + SmemDescriptor ret; + ret.lo = desc_.lo + uint32_t(offset); + ret.hi = desc_.hi; + return { ret }; + } +}; + +template +CUTE_HOST_DEVICE constexpr +SmemDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +CUTE_HOST_DEVICE void +print(DescriptorIterator const&) { + printf("UMMA::DescriptorIterator"); +} + +// Flag for smem descriptor allocation/creation +template +struct smem_desc : DescriptorIterator {}; + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace UMMA + +// Customization point for creating a UMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a UMMA Desc Tensor"); + return make_tensor(UMMA::DescriptorIterator{UMMA::make_umma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +// Customization point for creating a UMMA::sparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as UMMA::smem_desc above. + // Only the interface validates that we are passed a sparse_ptr, which is recast away to construct + // the smem desc tensor + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a UMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(UMMA::DescriptorIterator{UMMA::make_umma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +// Special smem_desc_iter tensor entry for UTCCP copy. +template +constexpr auto get_utccp_smem_desc_tensor(Tensor const& smem_utccp_partitioned_tensor) { + using VecLayout = decltype(layout<0>(TLayout{})); + static_assert(VecLayout::rank == 2 && shape<1>(VecLayout{}) == 1, "Mismatched vec_mode tensor."); + static_assert(is_smem::value, "Expect vec_mode smem_tesnor."); + static_assert(is_static::value, "Utccp copy tensor's vec_mode should be static."); + + using value_type = typename TEngine::value_type; + using UtccpTaits = Copy_Traits; + + // UtccpTaits::ValID: logical_bit_idx -> tmem_offset. + // We arrange the logical_bit_idx in order of (core_matrix_strided, core_matrix_leading, repeat(only in 64dplw01), broadcast). + // So we only need the first two modes for src smem_tensor. + auto utccp_core_matrix_shape = take<0,2>(upcast>(typename UtccpTaits::ValID{}).shape()); + // logical_bit_idx -> smem_addr + Layout vec_v_layout = flatten(layout<0>(VecLayout{})); + Layout utccp_core_matrix_layout = vec_v_layout.with_shape(utccp_core_matrix_shape); + Tensor utccp_core_matrix_tensor = group_modes<0,2>(make_tensor(smem_utccp_partitioned_tensor.data(), utccp_core_matrix_layout)); + Tensor core_matrix_desc_tensor = make_tensor>(utccp_core_matrix_tensor); + return make_tensor(core_matrix_desc_tensor.data(), recast_layout(smem_utccp_partitioned_tensor.layout())); +} + +namespace UMMA { + +enum class TmemAllocMode { + // Default allocation mode. + // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be + // interleaved by using the top-half-subpartition and the bottom-half-subpartition. + // Full utilization of TMEM capacity. 
+ Interleaved = 0, + // Prevents interleaving. + // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms will not be + // interleaved. + // Required for DP-address equivalence in TMEM-A and TMEM-C allocations in UMMA_TS. + NonInterleaved = 1, + // Duplicates the TMEM allocation across subpartitions. + // E.g. UMMA_2SM_128xNx16_TS uses a "2x2 DP" TMEM Layout, but the TMEM allocation is + // actually doubled and the input data must be duplicated between the + // subpartitions [0,1]<->[2,3], i.e., each subpartition holds all columns + // of the A matrix needed for a single UMMA operation. + // For UMMA_2SM_128xNx16_TS, the distribution of the data is as follows. + // SM0: + // Subpart0 = A[0:32, 0:16], Subpart1 = A[32:64, 0:16], + // Subpart2 = A[0:32, 0:16], Subpart3 = A[32:64, 0:16] + // SM1: + // Subpart0 = A[64:96, 0:16], Subpart1 = A[96:128, 0:16], + // Subpart2 = A[64:96, 0:16], Subpart3 = A[96:128, 0:16] + Duplicated = 2, + // Duplicates the TMEM allocation across subpartitions for scale factor. + // Scale factor TMEM allocation for 4x1 data path + ScaleFactorDuplicated4by1 = 3, + // Scale factor TMEM allocation for 2x2 data path + ScaleFactorDuplicated2by2 = 4 +}; + +struct tmem_frg_base {}; + +// The UMMA Traits below have custom fragment type flags for their tmem tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. +template +struct tmem_frg : tmem_frg_base +{ + static_assert(sizeof_bits_v <= sizeof_bits_v, "TMEM MMA allocations require StorageType big enough for ValueType."); + + // UMMA TMEM Allocator + // Each UMMA expects a specific MxN layout of TMEM for accumulators + // and sometimes a specific MxK layout of TMEM for A-values. + // @tparam ValueType The value type of the TMEM Tensor to allocate. + // @tparam StorageType The storage type of the TMEM Tensor to allocate. + // "Sparse" allocations often allocate ValueType=half_t within StorageType=uint32_t. + // "Dense" allocations often allocate ValueType=half_t within StorageType=half_t. + // @tparam N_SM The number of SMs in this UMMA_XSM instruction. + // @tparam TmemAlloc UMMA-specific allocation modifier for special cases. + // Some UMMA instructions expect strange atoms or tilings of atoms. + // @param tmem_shape ((M_MMA_SM,N_MMA_SM),MMA_M,MMA_N,...) + // The post-MMA-partitioned shape of TMEM to allocate. + // Note that for UMMA_2SM_128xNx16, M_MMA_SM will be 64, for example. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(size(tmem_shape)*Int)>{} <= TMEM::MAX_CAPACITY_BITS{}, + "Requesting more TMEM than is available."); + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. 
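+ // Each N_SM/M_MMA branch below defines its tmem_atom in this virtual (DP,COL)
+ // space and converts it to the logical TMEM addressing by composing with
+ // tmem_restride before returning.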
+ using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + static_assert(N_SM == 1 || N_SM == 2, "UMMA expects N_SM == 1 or N_SM == 2"); + if constexpr (N_SM == 1) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved || TmemAlloc == UMMA::TmemAllocMode::NonInterleaved, + "UMMA_1SM only accepts Interleaved or NonInterleaved"); + static_assert(M_MMA == 64 || M_MMA == 128, "UMMA_1SM M-mode size should be 64 or 128."); + + if constexpr (M_MMA == 64) + { + // Half subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, Int>, + Stride, _128>>{}; + // tile_stride = 2 causes the tiling to "skip" the first tile in DPs + constexpr int tile_stride = TmemAlloc == UMMA::TmemAllocMode::Interleaved ? 1 : 2; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape),Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // Full subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + + } else + if constexpr (N_SM == 2) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved || TmemAlloc == UMMA::TmemAllocMode::Duplicated, + "UMMA_2SM only accepts Interleaved or Duplicated"); + static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, "UMMA_2SM M-mode size should be 32 or 64 or 128."); + + if constexpr (M_MMA == 32) // TODO: Implement Duplicated mode for M_MMA = 32 + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved, "Only TmemAllocMode::Interleaved is supported for UMMA_2SM M_MMA=32"); + // The "1x4" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, _4>>, + Stride< _1,Stride< _128,_32>>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 64 && TmemAlloc == UMMA::TmemAllocMode::Interleaved) + { + // The "2x2" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, _2>>, + Stride< _1,Stride< _128,_64>>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + + } else + if constexpr (M_MMA == 64 && TmemAlloc == UMMA::TmemAllocMode::Duplicated) + { + // The "2x2" duplicated layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if 
constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // The "4x1" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// Convenient aliases for common cases in the UMMA::ElementXFrg below +template +using tmem_frg_1sm = tmem_frg; +template +using tmem_frg_2sm = tmem_frg; + +// Make metadata TMEM fragments for sparse MMAs. +// Also note that the TMEM fragment addresses are assumed to be COL-4 aligned -- working with arch to remove this condition +template +struct tmem_e_frg : tmem_frg_base +{ + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + static_assert(M_MMA == 128, "Only 128 implemented right now."); + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + [[maybe_unused]] Layout tmem_restride = Layout, + Stride>{}; + + if constexpr (sizeof_bits::value == 32) // TF32: 128x16 atom + { + static_assert(N_MMA == 16); + Layout tmem_atom = Layout, Shape < _8,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } else + if constexpr (sizeof_bits::value == 16) // FP16: 128x32 atom + { + static_assert(N_MMA == 32); + Layout tmem_atom = Layout, Shape < _16,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } else + if constexpr (sizeof_bits::value == 8) // S8|Mix.F4/F6/F8: 128x64 atom + { + // For Mix 8bit f4/f6/f8, will pass in ValueType = uint8_t + static_assert(N_MMA == 64); + Layout tmem_atom = Layout, + Stride< _1,_128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + if constexpr (sizeof_bits::value == 4) // F4: 128x128 atom + { + // For F4, will pass in ValueType = fp4 + Layout tmem_restride1 = Layout>, + Stride, _1>>{}; + // F4 has roughly same TMEM layout as 
Mix8bit.F4/F6/F8, the only difference is that K is multiplied by two + static_assert(N_MMA == 128); + Layout tmem_atom = Layout, + Stride< _1, _128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride1, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +template +struct tmem_e_frg_ws : tmem_frg_base +{ + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + static_assert(M_MMA == 128 || M_MMA == 64 || M_MMA == 32, "Weight stationary UMMA_1SM M-mode size should be 32 or 64 or 128."); + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + Layout tmem_restride = Layout, + Stride>{}; + + if constexpr (sizeof_bits::value == 32) // TF32 + { + // MMA_M x MMA_K: 128x16 atom / 64x16 atom / 32x16 atom + static_assert(N_MMA == 16); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, Shape < _8,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout, Shape < _8,_2>, _2>, + Stride, Stride<_128,_8>,_64>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles its own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout, Shape < _8,_2>, _4>, + Stride, Stride<_128,_8>,_32>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles its own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else if constexpr (sizeof_bits::value == 16) // FP16 + { + // MMA_M x MMA_K: 128x32 atom / 64x32 atom / 32x32 atom + static_assert(N_MMA == 32); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, Shape < _16,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address 
transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout, Shape < _16,_2>, _2>, + Stride, Stride<_128,_8>,_64>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout, Shape < _16,_2>, _4>, + Stride, Stride<_128,_8>,_32>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else if constexpr (sizeof_bits::value == 8) // I8|F8 + { + // MMA_M x MMA_K: 128x64 atom / 64x64 atom / 32x64 atom + static_assert(N_MMA == 64); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, + Stride< _1,_128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout>, + Stride< _1, Stride<_128, _64>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout>, + Stride< _1, Stride<_128, _32>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else { + static_assert(dependent_false, "Invalid ValueType"); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +template +struct tmem_sf_frg: tmem_frg_base +{ + // UMMA TMEM Allocator for Scale Factor A for Mxf4Nvf4 and Mxf8f6f4 instructions + // We expect a tensor that has the same layout as A matrix + // @tparam ValueType: data type of scaling factor + // Note that the StorageType is the same as ValueType, i.e., we always use a compact allocation + // @tparam SFVecSize: The number of values that is scaled by a single scaling factor. 
+ // Valid values are (16, 32) + // @tparam N_SM: Number of SMs in UMMA instruction + // @param tmem_shape: An MMA partitioned shape where first mode encodes, A layout of the MMA instruction. + // Note that the shape doesn't match the actual allocation. size<0,1>(tmem_shape) will give us the number of + // elements in K-mode of MMA rather than the number of scaling factors. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int MMA_MN = decltype(size<0,0>(tmem_shape))::value; + constexpr int MMA_VS = decltype(size<0,1,0>(tmem_shape))::value; + constexpr int MMA_NSF = decltype(size<0,1,1>(tmem_shape))::value; + constexpr int R_MMA_K = decltype(rank(get<0,1>(tmem_shape)))::value; + constexpr int R = decltype(rank(tmem_shape))::value; + + // We expect an MMA-SF partitioned tensor + // ((MMA_MN, (VecSize, NSF)), num_MMA_MN, num_MMA_K, ...) + // where VecSize*NSF = MMA_K + static_assert(R >= 3, "Expected an MMA partitioned tensor"); // ((MMA), num_MMA_MN, num_MMA_K, ...) + static_assert(R_MMA_K == 2, "Expected an MMA-SF partitioned tensor"); // (VecSize, NSF) + using REP = _4; // Replication factor. Data is always replicated across subpartitions + constexpr int SUBPART_DPs = 32; // Number of DPs in a subpartition + + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + if constexpr (Is_SFA || (!Is_SFA && TmemAlloc == UMMA::TmemAllocMode::ScaleFactorDuplicated4by1)) { + // SFA, 2x2 and 4x1 data path + // SFB, 4x1 data path + auto tmem_atom = Layout < Shape< Shape< Shape, Int>, REP>, Shape, Int>>, + Stride, _32>, Stride< _0, _128>>>{}; + + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + auto final_tmem_layout = composition(tmem_restride, tmem_logical_layout); + return make_tensor(make_tmem_ptr(), final_tmem_layout); + } + else { + // SFB, 2x2 datapath + static_assert(!Is_SFA and TmemAlloc == UMMA::TmemAllocMode::ScaleFactorDuplicated2by2); + static_assert(N_SM == 2, "Should be 2x2 Datapath"); + // 2x2 Datapth + auto tmem_atom = Layout < Shape< Shape< Shape, Int>, _2, _2>, Shape, Int>>, + Stride, _64, _32>, Stride< _0, _128>>>{}; + + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + auto final_tmem_layout = composition(tmem_restride, tmem_logical_layout); + return make_tensor(make_tmem_ptr(), final_tmem_layout); + } + } +}; + +// Make C/D Tmem fragment for weight-stationary MMAs +template +struct tmem_frg_ws : tmem_frg_base +{ + static_assert(sizeof_bits_v <= sizeof_bits_v, "TMEM MMA allocations require StorageType big enough for ValueType."); + + // UMMA TMEM Allocator + // Each UMMA expects a specific MxN layout of TMEM for accumulators + // and sometimes a specific MxK layout of TMEM for A-values. + // @tparam ValueType The value type of the TMEM Tensor to allocate. + // @tparam StorageType The storage type of the TMEM Tensor to allocate. + // "Sparse" allocations often allocate ValueType=half_t within StorageType=uint32_t. + // "Dense" allocations often allocate ValueType=half_t within StorageType=half_t. + // @tparam N_SM The number of SMs in this UMMA_XSM instruction. + // @tparam TmemAlloc UMMA-specific allocation modifier for special cases. + // Some UMMA instructions expect strange atoms or tilings of atoms. + // @param tmem_shape ((M_MMA_SM,N_MMA_SM),MMA_M,MMA_N,...) 
+ // The post-MMA-partitioned shape of TMEM to allocate. + // Note for UMMA_2SM_128xNx16, that M_MMA_SM will be 64, for example. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(size(tmem_shape)*Int)>{} <= TMEM::MAX_CAPACITY_BITS{}, + "Requesting more TMEM than is available."); + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + static_assert(N_SM == 1, "UMMA.WS expects N_SM == 1"); + + static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, + "Weight stationary UMMA_1SM M-mode size should be 32 or 64 or 128."); + static_assert(N_MMA == 64 || N_MMA == 128 || N_MMA == 256, + "Dense weight stationary UMMA_1SM N-mode size should be 64 or 128 or 256."); + // Weight Stationary MMA config + if constexpr (M_MMA == 32) + { + // 1x4 datapath + Layout tmem_atom = Layout, _4>>, + Stride< _1, Stride< _128,_32>> + >{}; + constexpr int tile_stride = 1; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape), Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 64) + { + // 2x2 datapath + Layout tmem_atom = Layout, _2>>, + Stride< _1, Stride< _128,_64>> + >{}; + constexpr int tile_stride = 1; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape), Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. 
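+ // (Same full-subpartition atom as the dense tmem_frg M_MMA == 128 case above.)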
+ // Full subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// Convenient aliases for common cases in the UMMA::ElementXFrg below +template +using tmem_frg_ws_1sm = tmem_frg_ws; + +} // end namespace UMMA + +// Customization point for creating a UMMA::tmem_frg Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_frg::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_frg_ws::make(shape(tmem_shape)); + } +}; + + +// Customization point for creating a UMMA::tmem_frg Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_e_frg::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_e_frg_ws::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_sf_frg::make(shape(tmem_shape)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + 
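+  // Note: these traits are not invoked directly. They are typically wrapped in a
+  // cute::MMA_Atom / cute::TiledMMA, and cute::gemm() on the partitioned tensors is
+  // what eventually dispatches to the mma_unpack() defined below.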
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 8 between 8 and 256."); + + // Logical shape-K is always 256bits, transform to units of elements 
+ constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4 supports types with leq 8bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + constexpr static int SFVecSize = 32; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable MMA_traits with sp into set. 
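+  // (That is, a copy of these traits whose accumulate flag and whose TMEM addresses of
+  // the A/B scale factors are filled in from the given SFA/SFB tensors, ready to be
+  // executed by mma_unpack() above.)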
+ template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 32; + + static_assert(M == 128 || M == 256, "MMA_F8F6F4 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA_F8F6F4 N-mode size should be a multiple of 16 between 16 and 256."); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + constexpr static int SFVecSize = 32; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? 
+ UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS 64 ? M/2 : M), (N == 192 ? 256 : N), a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4 supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 32 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 16), + "2x mode (VectorSize=32) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4 supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (N == 192 ? 256 : N), VS, a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + + +} // end namespace cute diff --git a/include/cute/atom/partitioner.hpp b/include/cute/atom/partitioner.hpp new file mode 100644 index 0000000000..75a55ccf6b --- /dev/null +++ b/include/cute/atom/partitioner.hpp @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include + +namespace cute { + +// +// A generic tiling of thread-value layouts +// + +template coord [Need not be 2D...] + class Tiler_MN_> // coord space +struct TV_Tiler +{ + using Tiler_MN = Tiler_MN_; + using TiledLayout_TV = Layout_TV_; + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,FrgV),(RestM,RestN,...)) + // where + // ThrV: The threads local to a tile. + // FrgV: The values local to a tile. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + apply(Tensor&& tensor) + { + // If Layout_TV and Tiler_MN were composable in general, then this won't be needed! + + // ((thr_id,val_id),(RestM,RestN,...)) + return zipped_divide(tensor, Tiler_MN{}).compose(TiledLayout_TV{}, _); + } + + template + struct TV_Partitioner + { + SliceCoord coord_; + + template + CUTE_HOST_DEVICE + auto + partition(TargetTensor&& target) { + Tensor thr_tensor = make_tensor(static_cast(target).data(), apply(target.layout())); + return thr_tensor(coord_, repeat>(_)); + } + }; + + template + CUTE_HOST_DEVICE static + auto + get_slice(SliceCoord const& coord) + { + return TV_Partitioner{coord}; + } +}; + +template +CUTE_HOST_DEVICE +auto +make_tiler_impl(Layout_TV const&, + Tiler_MN const&) +{ + return TV_Tiler{}; +} + +} diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index f2505b35f2..dab8621e82 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -119,12 +119,16 @@ template CUTE_HOST_DEVICE constexpr T getv(EBO const&) { return {}; } +// This is a work around approach to solve a shared memory misalign issue (https://github.com/NVIDIA/cutlass/issues/1250). +// Will remove this work around implementation once the corresponding fix in compiler is released. +struct dummy_EBO_base {}; + // Specialization for types T that are not empty; // the "dynamic tuple leaf." Valid T here include int, // any other integral or floating-point type, // or any semiregular type for which std::is_empty_v is false. 
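+// (The private dummy_EBO_base base introduced below is the compiler workaround
+// referred to above.)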
template -struct EBO +struct EBO : private dummy_EBO_base { CUTE_HOST_DEVICE constexpr EBO() : t_{} {} diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index 7031e7aba3..c2e7456e5f 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -78,6 +78,7 @@ using int_byte_t = typename int_byte::type; using uint1_t = cutlass::uint1b_t; using uint2_t = cutlass::uint2b_t; using uint4_t = cutlass::uint4b_t; +using uint6_t = cutlass::uint6b_t; using CUTE_STL_NAMESPACE::uint8_t; using CUTE_STL_NAMESPACE::uint16_t; using CUTE_STL_NAMESPACE::uint32_t; @@ -88,6 +89,7 @@ template struct uint_bit; template <> struct uint_bit< 1> { using type = uint1_t; }; template <> struct uint_bit< 2> { using type = uint2_t; }; template <> struct uint_bit< 4> { using type = uint4_t; }; +template <> struct uint_bit< 6> { using type = uint6_t; }; template <> struct uint_bit< 8> { using type = uint8_t; }; template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp index c566916569..892ec70686 100644 --- a/include/cute/numeric/numeric_types.hpp +++ b/include/cute/numeric/numeric_types.hpp @@ -73,6 +73,29 @@ using cutlass::uint4b_t; using cutlass::bin1_t; +using cutlass::float_ue4m3_t; +using cutlass::float_ue8m0_t; + +using cutlass::uint6b_t; +using cutlass::float_e2m1_t; +using cutlass::float_e2m3_t; +using cutlass::float_e3m2_t; + +using cutlass::type_erased_dynamic_float6_t; +using cutlass::type_erased_dynamic_float4_t; + +namespace detail { +using cutlass::detail::float_e2m1_unpacksmem_t; +using cutlass::detail::float_e2m3_unpacksmem_t; +using cutlass::detail::float_e3m2_unpacksmem_t; +using cutlass::detail::float_e2m3_unpack8bits_t; +using cutlass::detail::float_e3m2_unpack8bits_t; +using cutlass::detail::type_erased_dynamic_float4_unpacksmem_t; +using cutlass::detail::type_erased_dynamic_float6_unpacksmem_t; +}; + + + // // Print utility // @@ -133,4 +156,26 @@ pretty_print(float_e5m2_t t) { printf("%*.2f", 8, static_cast(t)); } + +template < + cutlass::detail::FpEncoding Encoding, + class Derived +> +CUTE_HOST_DEVICE +void +print(cutlass::float_exmy_base a) { + printf("%f", static_cast(a)); +} + +template < + cutlass::detail::FpEncoding Encoding, + class Derived +> +CUTE_HOST_DEVICE +void +pretty_print_float_exmy_base(cutlass::float_exmy_base t) { + printf("%*.2f", 8, static_cast(t)); +} + + } // namespace cute diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 4df82c72a1..3c42fd298c 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -284,6 +284,96 @@ recast_ptr(rmem_ptr
const& ptr) { return make_rmem_ptr(recast_ptr(ptr.get())); } + +// +// tmem_ptr -- a typed, word-addressed, non-dereferencable "pointer" +// + +template +struct tmem_ptr +{ + using value_type = remove_cv_t; + using element_type = T; + using reference = T; + + // Right-shift value for the offset scaling -- TMEM uses word-addressing + static constexpr int32_t OffsetShift = log_2(trait_ratio(sizeof_bits{}, sizeof_bits{})); + + CUTE_HOST_DEVICE constexpr + tmem_ptr(uint32_t addr = 0) : addr_(addr) {} + + CUTE_HOST_DEVICE constexpr + uint32_t const& get() const { + return addr_; + } + CUTE_HOST_DEVICE constexpr + uint32_t& get() { + return addr_; + } + + template + CUTE_HOST_DEVICE constexpr + value_type operator*() const { + static_assert(dependent_false, "Attempting to dereference a tmem_ptr, want raw_pointer_cast() for address instead?"); + return value_type{}; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](uint32_t const& i) const { return *(*this + i); } + + CUTE_HOST_DEVICE constexpr + tmem_ptr operator+(uint32_t const& i) const { + //return {addr_ + shiftr(i, OffsetShift)}; // Shift the offset for word-addressing + return {addr_ + rotr(i, OffsetShift)}; // Rotate the offset to keep subword indices in the unused high 8bits for debug + } + + // TMEM "Address" with active mask 0x007F.01FF + // The upper 16 bits, the 0x007F portion, refers to the 128 DP lanes + // The lower 16 bits, the 0x01FF portion, refers to the 512 COL lanes + union { + uint32_t addr_; + struct { + uint16_t col_; + uint8_t dp_; + uint8_t idx_; // Hijack the top 8bits for the sub-word idx to avoid an extra reg. + // Assert this is 0 on every access? + }; + }; +}; + +template +struct is_tmem : false_type {}; +template // Found the tmem +struct is_tmem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_tmem> : is_tmem {}; +template +constexpr bool is_tmem_v = is_tmem
::value; + +template +CUTE_HOST_DEVICE constexpr +tmem_ptr +make_tmem_ptr(uint32_t addr = 0) { + return tmem_ptr(addr); +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t +raw_pointer_cast(tmem_ptr const& ptr) { + return ptr.get(); +} + +// TMEM accounts for subword/superword elements already due to the offset shift based on sizeof_bits +// Thus, this is a trivial recast equivalent to reinterpret_cast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(tmem_ptr const& ptr) { + return tmem_ptr{ptr.addr_}; +} + + // // Display utilities // @@ -306,6 +396,14 @@ CUTE_HOST_DEVICE void print(rmem_ptr ptr) printf("rmem_"); print(ptr.get()); } + +template +CUTE_HOST_DEVICE void print(tmem_ptr ptr) +{ + printf("tmem_["); print(sizeof_bits::value); printf("b](0x%04x.%04x)", ptr.addr_ >> 16, ptr.addr_ & 0xFFFF); +} + + #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr ptr) @@ -325,6 +423,13 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr ptr) return os << "rmem_[" << int(sizeof_bits>::value) << "b]"; } + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, tmem_ptr ptr) +{ + return os << "tmem_[" << int(sizeof_bits::value) << "b](" << ptr.addr_ << ")"; +} + #endif // !defined(__CUDACC_RTC__) } // end namespace cute diff --git a/include/cute/tensor_zip.hpp b/include/cute/tensor_zip.hpp index 3b9b2ae3a2..279c4054d4 100644 --- a/include/cute/tensor_zip.hpp +++ b/include/cute/tensor_zip.hpp @@ -95,6 +95,9 @@ template struct is_smem> : conjunction...> {}; template struct is_gmem> : conjunction...> {}; +template +struct is_tmem> : conjunction...> {}; + // A tuple of Layouts that operates on each Layout symmetrically // The Layouts need to have compatible shapes and ranks. // The ZipLayout presents the intersection of the domain of its component Layouts. 
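The tmem_ptr above is easiest to read from the caller's side: it only carries the packed 32-bit DP/COL address that tcgen05 instructions consume, and any attempt to dereference it is rejected at compile time. A minimal sketch of forming and inspecting a TMEM-backed tensor follows; the function name, shape, strides, and base address are illustrative only, not part of the library.

#include <cute/tensor.hpp>

// Device-side sketch; assumes an sm_100a build and a valid TMEM base address.
__device__ void tmem_view_sketch(uint32_t tmem_base_addr) {
  using namespace cute;
  // View 128 DP lanes x 64 columns of 32-bit accumulators starting at tmem_base_addr.
  // In the packed address a DP step is 1 << 16 and a COL step is 1, matching the
  // 0x007F.01FF mask documented above.
  Tensor tC = make_tensor(make_tmem_ptr<float>(tmem_base_addr),
                          make_layout(make_shape (Int<128>{}, Int<64>{}),
                                      make_stride(Int<(1 << 16)>{}, Int<1>{})));
  uint32_t addr = raw_pointer_cast(tC.data());   // low 16 bits: COL, bits 16..22: DP lane
  print(tC.data());                              // prints e.g. "tmem_[32b](0x0000.0000)"
  (void)addr;
}

A 32-bit element type keeps the sketch simple: OffsetShift is zero for such types, so layout offsets map directly onto the DP/COL address without any sub-word rotation.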
diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index 72c852e293..e6cc887adc 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -255,7 +255,12 @@ pretty_print(double v) { template CUTE_HOST_DEVICE void pretty_print(T t) { + constexpr auto has_print_exmy_base = cute::is_valid([](auto t) -> decltype(pretty_print_float_exmy_base(t)) {}, t); + if constexpr (has_print_exmy_base) { + pretty_print_float_exmy_base(t); + } else { printf(" "); print(t); + } } } // end namespace cute diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index e88597007c..f534f6cd8b 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -41,6 +41,7 @@ namespace cutlass { namespace arch { +constexpr int sm100_smem_capacity_bytes = 232448; #if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) /// Computes laneId within a warp @@ -93,6 +94,12 @@ struct Sm90 { static int const kMinComputeCapability = 90; }; + +struct Sm100 { + static int const kMinComputeCapability = 100; +}; + + /// Triggers a breakpoint on the device CUTLASS_DEVICE void device_breakpoint() { diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index ad4564cafc..b9bb70f9bc 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -36,12 +36,21 @@ #include #include +#include +#include + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) #define CUDA_BARRIER_ENABLED 1 #else #define CUDA_BARRIER_ENABLED 0 #endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) +#define CUTLASS_ARCH_TCGEN_ENABLED 1 +#endif + + namespace cutlass { /// @brief namespace arch { @@ -140,6 +149,15 @@ void initialize_barrier_array_pair_aligned(uint64_t *full_barriers_ptr, uint64_t } // namespace detail end + + +// There are 16 Named Barriers provided by Hardware starting in Hopper +// Their IDs are in the range 0-15 +// Number of threads syncing using the barrier must be a multiple of warp-size +// ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads) +// may use it and conflict with other uses. + + // Enumerates the reserved named barriers to avoid potential conflicts // This enum class specifies the NamedBarriers reserved by CUTLASS. 
enum class ReservedNamedBarriers { @@ -148,6 +166,7 @@ enum class ReservedNamedBarriers { TransformBarrier = 3, StreamkBarrier0 = 4, StreamkBarrier1 = 5 + , TmemAllocBarrier = 6 , FirstUserBarrier = StreamkBarrier1 + 1 }; @@ -735,6 +754,152 @@ void cpasync_barrier_arrive_noinc(uint64_t const* smem_ptr) { //////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_DEVICE +void umma_arrive(uint64_t const* smem_ptr) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + if (cute::elect_one_sync()) { + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + : + :"r"(bar_intptr)); + } +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +//UMMA arrive for MMA_2x1SM +CUTLASS_DEVICE +void umma_arrive_2x1SM(uint64_t const* smem_ptr) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + if (cute::elect_one_sync()) { + asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];" + : + :"r"(bar_intptr)); + } +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +// UMMA arrive for MMA_1sm + TMA_LOAD_MULTICAST combination +CUTLASS_DEVICE +void umma_arrive_multicast(uint64_t const* smem_ptr, uint16_t cta_mask) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + if(cute::elect_one_sync()) { + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); + } +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +// UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination +CUTLASS_DEVICE +void umma_arrive_multicast_2x1SM(uint64_t const* smem_ptr, uint16_t cta_mask) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); + } +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +// Temporary solution for sparse kernel. +// Will remove this when we done tightly elect_one wrap. +CUTLASS_DEVICE +void umma_arrive_multicast_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .b16 lo, hi;\n\t" + "mov.b32 {lo, hi}, %1;\n\t" + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t" + "}" + : + :"r"(bar_intptr), "r"(uint32_t(cta_mask))); +#elif defined(__CUDA_ARCH__) + CUTLASS_NOT_IMPLEMENTED(); +#endif +} + +// Temporary solution for sparse kernel. 
+// UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination +CUTLASS_DEVICE +void umma_arrive_multicast_2x1SM_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .b16 lo, hi;\n\t" + "mov.b32 {lo, hi}, %1;\n\t" + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t" + "}" + : + :"r"(bar_intptr), "r"(uint32_t(cta_mask))); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif +} + +// Always arrive on even SM of collaborating 2 SMs. +CUTLASS_DEVICE +void umma_arrive_2x1SM_sm0(uint64_t const* smem_ptr) { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr) & cute::Sm100MmaPeerBitMask; + asm volatile ( + "{\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [%0];\n\t" + "}" + : + : "r"(bar_intptr)); + +#else + asm volatile ("brkpt;\n" ::); +#endif +} + +CUTE_DEVICE static void fence_view_async_tmem_load() { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + asm volatile ( + "{\n\t" + "tcgen05.wait::ld.sync.aligned; \n" + "}" + ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +CUTE_DEVICE static void fence_view_async_tmem_store() { +#if defined(CUTLASS_ARCH_TCGEN_ENABLED) + asm volatile ( + "{\n\t" + "tcgen05.wait::st.sync.aligned; \n" + "}" + ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + + //////////////////////////////////////////////////////////////////////////////////////////////////// } // end namespace arch } // end namespace cutlass diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index 5f842a4bb2..10b6af8a75 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -51,21 +51,32 @@ #endif #endif -#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 2)) #define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED #endif ///////////////////////////////////////////////////////////////////////////////////////////////// -// SM90 Modifiable +// SM90 Modifiable TMA #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3)) #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1 - #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1 + #endif +#endif - #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) - #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED 1 +#if (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ == 8) + #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && \ + !defined(__CUDA_ARCH_FEAT_SM90_ALL) + #undef CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED #endif + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1000 && \ + !defined(__CUDA_ARCH_FEAT_SM100_ALL) + #undef CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED + #endif + #endif #endif @@ -79,7 +90,29 @@ #endif #endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM100, SM100a +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) 
+ #define CUTLASS_ARCH_MMA_SM100_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM100_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1000) + #define CUTLASS_ARCH_MMA_SM100_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM100_ALL)) + #define CUTLASS_ARCH_MMA_SM100A_ENABLED 1 + #endif + #endif +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// + + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) +# define CUTLASS_ARCH_CLC_ENABLED +#endif + + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 2fcf2ee1c2..fb8c744050 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -129,6 +129,11 @@ struct OpClassWmmaTensorOp {}; /// Tag classifying operators as Tensor Core with structure sparse operations. struct OpClassSparseTensorOp {}; + +/// Tag classifying operators as Tensor Core with blockScaled +struct OpClassBlockScaledTensorOp {}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 0258d0d57c..e1e182827f 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -2567,7 +2567,6 @@ struct bit_not> { } }; - /// bit_xor template struct bit_xor> { @@ -2590,6 +2589,137 @@ struct bit_xor> { } }; +/// Fused and-popc-add +template +struct and_popc_add, Array, Array> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + Array result; + and_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + Array result; + and_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + Array result; + and_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + + +/// Fused or-popc-add +template +struct or_popc_add, Array, Array> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + Array result; + or_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + Array result; + or_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + Array result; + or_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + +/// Fused xor-popc-add +template +struct xor_popc_add, Array, Array> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + Array result; + xor_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) 
{ + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + Array result; + xor_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + Array result; + xor_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Operator overloads ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 3b089bf605..f9f2be8176 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -38,6 +38,8 @@ #include #include "cutlass/cutlass.h" #include "cutlass/trace.h" +#include + #if defined(__CUDACC_RTC__) #include #else @@ -49,6 +51,11 @@ # define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED #endif +#ifndef CUDA_ENABLE_PREFERRED_CLUSTER + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) + # define CUDA_ENABLE_PREFERRED_CLUSTER + #endif +#endif namespace cutlass { #ifndef NDEBUG @@ -78,7 +85,13 @@ struct ClusterLauncher { struct LaunchConfig { #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) cudaLaunchConfig_t launch_config; + + #if defined(CUDA_ENABLE_PREFERRED_CLUSTER) + constexpr static int numAttrs = 3; + #else + constexpr static int numAttrs = 2; + #endif cudaLaunchAttribute launch_attribute[numAttrs]; // Commonly used utility functions dim3 gridDim() { return launch_config.gridDim; } @@ -143,6 +156,7 @@ struct ClusterLauncher { size_t const smem_size = 0, cudaStream_t cuda_stream = 0, bool launch_with_pdl = false + , dim3 const fallback_cluster_dims = {0, 0, 0} ) { LaunchConfig cluster_launch_config; #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) @@ -151,9 +165,37 @@ struct ClusterLauncher { auto numAttrs = cluster_launch_config.numAttrs; launch_attribute[0].id = cudaLaunchAttributeClusterDimension; + + bool have_fallback = fallback_cluster_dims.x * fallback_cluster_dims.y * fallback_cluster_dims.z > 0; + + if (have_fallback) { + launch_attribute[0].val.clusterDim = {fallback_cluster_dims.x, fallback_cluster_dims.y, fallback_cluster_dims.z}; + CUTLASS_TRACE_HOST("ClusterLauncher: Setting fallback ClusterDims = " + "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n"); + } + else { + launch_attribute[0].val.clusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z}; CUTLASS_TRACE_HOST("ClusterLauncher: Setting ClusterDims = " "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + + } + +#if defined(CUDA_ENABLE_PREFERRED_CLUSTER) + if (have_fallback) { + if (cute::initialize_preferred_cluster_launch(nullptr, grid_dims, cluster_dims, fallback_cluster_dims)) { + launch_attribute[1].id = cudaLaunchAttributePreferredClusterDimension; + launch_attribute[1].val.preferredClusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z}; + CUTLASS_TRACE_HOST("ClusterLauncher: Setting preferred ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + } + } + else { + numAttrs--; + 
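// (No fallback cluster was requested, so the preferred-cluster attribute slot is unused; dropping it here
//  lets the PDL attribute below land at index numAttrs - 1 and avoids passing an uninitialized attribute.)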
} +#endif + + // PDL attributes launch_attribute[numAttrs - 1].id = cudaLaunchAttributeProgrammaticStreamSerialization; launch_attribute[numAttrs - 1].val.programmaticStreamSerializationAllowed = 1; @@ -198,7 +240,7 @@ struct ClusterLauncher { return Status::kInvalid; } - CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " + CUTLASS_TRACE_HOST("ClusterLauncher: Launching GridDims = " "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), " "And ClusterDims = " "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); @@ -212,6 +254,53 @@ struct ClusterLauncher { #endif } + + // This is the method we expect to use going forward + // Launch a preferred cluster grid + static inline CUTLASS_HOST + Status launch_with_fallback_cluster( + dim3 const grid_dims, + dim3 const preferred_cluster_dims, + dim3 const fallback_cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void const* kernel, + void** kernel_params, + bool launch_with_pdl = false) { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, preferred_cluster_dims, + block_dims, smem_size, cuda_stream, launch_with_pdl, fallback_cluster_dims); + + auto launch_grid_dims = cluster_launch_config.gridDim(); + if (check_cluster_dims(launch_grid_dims, preferred_cluster_dims) != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); + return Status::kInvalid; + } + + auto init_status = init(kernel); + if (init_status != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); + return Status::kInvalid; + } + + CUTLASS_TRACE_HOST("ClusterLauncher: Launching \n\tGridDims = " + "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), " + "\n\tPreferred ClusterDims = " + "(" << preferred_cluster_dims.x << ", " << preferred_cluster_dims.y << ", " << preferred_cluster_dims.z << ")," + "\n\tFallback ClusterDims = " + "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n"); + + cutlass::arch::synclog_setup(); + cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params); + Return_Status(status); +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); + return Status::kInvalid; +#endif + } + + }; namespace detail { diff --git a/include/cutlass/conv/collective/builders/sm100_common.inl b/include/cutlass/conv/collective/builders/sm100_common.inl new file mode 100644 index 0000000000..b502466664 --- /dev/null +++ b/include/cutlass/conv/collective/builders/sm100_common.inl @@ -0,0 +1,193 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// +#pragma once + +#include "cutlass/layout/tensor.h" +#include "cute/atom/copy_traits_sm100_im2col.hpp" +#include "cutlass/arch/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/conv/collective/builders/sm90_common.inl" +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Collective tile traits struct that serves as a type list containing a tensor's mem layouts and atoms +template< + class GmemTiledCopy_, + class SmemLayoutAtom_, + class TmemLayoutAtom_ = void +> +struct Sm100ImplicitGemmTileTraits { + using GmemTiledCopy = GmemTiledCopy_; + using SmemLayoutAtom = SmemLayoutAtom_; + using TmemLayoutAtom = TmemLayoutAtom_; +}; + +template +constexpr auto +sm100_cluster_shape_to_im2col_tma_atom_A(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) { + static_assert(cute::rank(cluster_shape_mnk) == 3); + constexpr bool IsDynamicCluster = not cute::is_static_v; + + if constexpr (cute::size(atom_thr_id) == 2) { + if constexpr (!IsDynamicCluster) { + static_assert(cute::size<0>(cluster_shape_mnk) % 2 == 0, "Cluster shape not divisible by MMA size"); + if constexpr (cute::size<1>(cluster_shape_mnk) == 1) { + return cute::SM100_TMA_2SM_LOAD_IM2COL{}; + } + else { + return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{}; + } + } + else { + return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{}; + } + } + else if constexpr (size(atom_thr_id) == 1) { + if constexpr (!IsDynamicCluster) { + return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::size<1>(cluster_shape_mnk)); + } + else { + // In the case of dynamic cluster, multicast decision is not known at compile time. + // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. 
+ return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::Int<2>{}); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported Configuration for SM100 TMA"); + } +} + +template +constexpr auto +sm100_cluster_shape_to_im2col_tma_atom_B(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) { + static_assert(cute::rank(cluster_shape_mnk) == 3); + constexpr bool IsDynamicCluster = not cute::is_static_v; + + if constexpr (cute::size(atom_thr_id) == 2) { + if constexpr (!IsDynamicCluster) { + static_assert(cute::size<0>(cluster_shape_mnk) % 2 == 0, "Cluster shape not divisible by MMA size"); + if constexpr (cute::size<0>(cluster_shape_mnk) == 2) { + return cute::SM100_TMA_2SM_LOAD_IM2COL{}; + } + else { + return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{}; + } + } + else { + return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{}; + } + } else if constexpr (size(atom_thr_id) == 1) { + if constexpr (!IsDynamicCluster) { + return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::size<0>(cluster_shape_mnk)); + } + else { + // In the case of dynamic cluster, multicast decision is not known at compile time. + // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. + return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::Int<2>{}); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported Configuration for SM100 TMA"); + } +} + +template< + class ElementA, + class ElementB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class KernelScheduleType +> +constexpr auto +sm100_make_tiled_mma() { + // MMA_2SM requested + if constexpr (cute::is_same_v) { + return cutlass::gemm::collective::detail::sm100_make_2sm_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>(); + } + // MMA_1SM requested + else if constexpr (cute::is_same_v) { + return cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>(); + } + // Auto scheduling requested + else if constexpr (cute::is_same_v) { + // Static cluster + if constexpr (cute::is_static_v) { + // For MMA_2SM we need a cluster shape that is multiple of 2x1 + // and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM + if constexpr (cute::size<0>(ClusterShape_MNK{}) % 2 == 0 && + cute::size<0>(TileShape_MNK{}) % 128 == 0) { + return cutlass::gemm::collective::detail::sm100_make_2sm_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>(); + } + else { + return cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>(); + } + // Dynamic cluster shape means we cannot assume we can use 2SM MMA + } + else { + return cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma< + ElementA, ElementB, ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>(); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported policy for SM100 collective builder."); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail + 
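// A minimal sketch of the KernelScheduleAuto dispatch rule implemented by sm100_make_tiled_mma() above.
// `prefers_2sm_mma` is a hypothetical helper used for illustration only; dynamic cluster shapes always
// take the MMA_1SM path regardless of this predicate.
constexpr bool prefers_2sm_mma(int cluster_m, int tile_m) {
  // MMA_2SM needs a cluster extent along M that is a multiple of 2 and an M tile that is a multiple of 128
  return (cluster_m % 2 == 0) && (tile_m % 128 == 0);
}
static_assert( prefers_2sm_mma(2, 256), "2x1 cluster with TileShape_M = 256 selects the 2SM tiled MMA");
static_assert(!prefers_2sm_mma(1, 128), "1x1 cluster falls back to the 1SM tiled MMA");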
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl new file mode 100644 index 0000000000..db1f7dae0a --- /dev/null +++ b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl @@ -0,0 +1,225 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +// + +// +#pragma once + +#include "cutlass/conv/collective/builders/sm100_common.inl" +#include "cutlass/conv/collective/builders/sm90_gmma_builder.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNKL, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNKL, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + ((sizeof(ElementA) * AlignmentA) % cutlass::gemm::collective::detail::tma_alignment_bytes == 0) && + ((sizeof(ElementB) * AlignmentB) % cutlass::gemm::collective::detail::tma_alignment_bytes == 0)>> { +private: + // For fprop, majorA = K, major B = K; + // For wgrad, majorA = MN, major B = MN; + // For dgrad, majorA = K, major B = MN; + static constexpr cute::UMMA::Major UmmaMajorA = + (ConvOp == conv::Operator::kWgrad) ? cute::UMMA::Major::MN : cute::UMMA::Major::K; + static constexpr cute::UMMA::Major UmmaMajorB = + (ConvOp == conv::Operator::kFprop) ? 
cute::UMMA::Major::K : cute::UMMA::Major::MN; + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + using TileShape_MNK = decltype(cute::take<0,3>(TileShape_MNKL{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) + + static constexpr auto + get_tiled_mma_schedule() { + if constexpr (cute::is_same_v) { + return KernelImplicitTmaWarpSpecialized1SmSm100{}; + } + else if constexpr (cute::is_same_v) { + return KernelImplicitTmaWarpSpecialized2SmSm100{}; + } + else { + return KernelScheduleType{}; + } + } + + using TiledMmaSchedule = decltype(get_tiled_mma_schedule()); + using TiledMma = decltype(detail::sm100_make_tiled_mma()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + static constexpr auto + get_tma_atom_A() { + if constexpr (cute::is_same_v || + cute::is_same_v) { + static_assert(ConvOp == conv::Operator::kDgrad, "Operator+Schedule mismatch"); + return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(ClusterShape_MNK{}, AtomThrID{}); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(ClusterShape_MNK{}, AtomThrID{}); + } + else { + return cutlass::conv::collective::detail::sm100_cluster_shape_to_im2col_tma_atom_A(ClusterShape_MNK{}, AtomThrID{}); + } + } + + static constexpr auto + get_tma_atom_B() { + if constexpr (cute::is_same_v || + cute::is_same_v) { + static_assert(ConvOp == conv::Operator::kDgrad, "Operator+Schedule mismatch"); + return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return cutlass::conv::collective::detail::sm100_cluster_shape_to_im2col_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}); + } + else { + return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}); + } + } + + // For wgrad kernel, tensor A uses tma tiled mode and tensor B uses tma im2col mode. 
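// Illustration only -- this table merely summarizes get_tma_atom_A()/get_tma_atom_B() above, it adds no behavior:
//   fprop         : A (activation)  -> im2col TMA,  B (filter)     -> tiled TMA
//   dgrad         : A (grad-output) -> im2col TMA,  B (filter)     -> tiled TMA
//   strided dgrad : A and B         -> tiled TMA (no im2col)
//   wgrad         : A (grad-output) -> tiled TMA,   B (activation) -> im2col TMA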
+ using GmemTiledCopyA = decltype(get_tma_atom_A()); + using GmemTiledCopyB = decltype(get_tma_atom_B()); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma, BlockTileA_M, BlockTileA_K>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma, BlockTileB_N, BlockTileB_K>()); + + // Calculate SMEM matrix A and B buffers' pipeline stages + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + static constexpr uint32_t SchedulerPipelineStageCount = 2; + static constexpr uint32_t CLCResponseSize = 16; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // LoadOrderBarrier = OrderedSequenceBarrier<1,2> + static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = SchedulerPipelineStageCount * sizeof(uint32_t); + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + LoadOrderBarrierStorage + + TmemDeallocStorage + + CLCThrottlePipelineStorage + + CLCResponseStorage + + TmemBasePtrsStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. 
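// Worked example, assuming the stage counts fixed above (the sizeof(...) terms depend on the pipeline types):
//   CLCResponseStorage  = SchedulerPipelineStageCount * CLCResponseSize  = 2 * 16 = 32 bytes
//   TmemBasePtrsStorage = SchedulerPipelineStageCount * sizeof(uint32_t) = 2 * 4  =  8 bytes
// The remaining terms are the SharedStorage footprints of the accumulator/CLC/throttle pipelines and the
// load-order and tmem-dealloc barriers; their sum is the carveout subtracted from the SM100 smem budget below.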
+ static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + static constexpr int PipelineStages = detail::compute_stage_count_or_override< + Sm100ReducedSmemCapacityBytes, ElementAMma, ElementBMma, SmemTileShape>(StageCountType{}); + + constexpr static int NumSpatialDimensions = detail::gmem_layout_tags_to_spatial_dims(); + + using DispatchPolicy = cutlass::conv::MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< + ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK>; + +public: + using CollectiveOp = cutlass::conv::collective::CollectiveConv< + DispatchPolicy, + TileShape_MNKL, + ElementA, + ElementB, + TiledMma, + detail::Sm100ImplicitGemmTileTraits, + detail::Sm100ImplicitGemmTileTraits + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/collective_builder.hpp b/include/cutlass/conv/collective/collective_builder.hpp index 278271d79a..e032f9599a 100644 --- a/include/cutlass/conv/collective/collective_builder.hpp +++ b/include/cutlass/conv/collective/collective_builder.hpp @@ -90,4 +90,5 @@ struct CollectiveBuilder { ///////////////////////////////////////////////////////////////////////////////////////////////// #include "builders/sm90_gmma_builder.inl" +#include "builders/sm100_umma_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/collective_conv.hpp b/include/cutlass/conv/collective/collective_conv.hpp index 8ecd6c9585..f0bb596fe0 100644 --- a/include/cutlass/conv/collective/collective_conv.hpp +++ b/include/cutlass/conv/collective/collective_conv.hpp @@ -59,4 +59,5 @@ struct CollectiveConv { ///////////////////////////////////////////////////////////////////////////////////////////////// #include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp" +#include "sm100_implicit_gemm_umma_warpspecialized.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/detail.hpp b/include/cutlass/conv/collective/detail.hpp index a986754653..af541a940f 100644 --- a/include/cutlass/conv/collective/detail.hpp +++ b/include/cutlass/conv/collective/detail.hpp @@ -167,6 +167,20 @@ sm90_dispatch_policy_to_stride_B() { } } + +template +constexpr auto +sm100_dispatch_policy_to_stride_A() { + return sm90_dispatch_policy_to_stride_A(); +} + +template +constexpr auto +sm100_dispatch_policy_to_stride_B() { + return sm90_dispatch_policy_to_stride_B(); +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Compute the lower/near corner, returning it as a cute::array in [W,H,D] order @@ -247,8 +261,11 @@ compute_lower_srt(ConvProblemShape const& problem_ } template struct is_im2col_load { static constexpr bool value = false; }; -template <> struct is_im2col_load { static constexpr bool value = true; }; -template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; 
+template <> struct is_im2col_load { static constexpr bool value = true; }; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv::collective::detail diff --git a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp new file mode 100644 index 0000000000..cca462d5fd --- /dev/null +++ b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -0,0 +1,899 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/cluster.hpp" + +#include "cutlass/conv/detail.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/trace.h" + +#if (! 
defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) +# include <sstream> +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + conv::Operator ConvOp, + int Stages, + int NumSpatialDims, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) + class ElementA_, + class ElementB_, + class TiledMma_, + class TileTraitsA_, + class TileTraitsB_> +struct CollectiveConv< + MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape>, + TileShapeMNKL_, + ElementA_, + ElementB_, + TiledMma_, + TileTraitsA_, + TileTraitsB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape>; + using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) + using ElementA = ElementA_; + using ElementB = ElementB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; + using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; + using SmemLayoutAtomA = typename TileTraitsA_::SmemLayoutAtom; + using SmemLayoutAtomB = typename TileTraitsB_::SmemLayoutAtom; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions; + static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; + // deduce the kernel facing stride tuple types based on the dispatch policy (spatial dim, algo, etc.)
+ using StrideA = decltype(detail::sm100_dispatch_policy_to_stride_A()); + using StrideB = decltype(detail::sm100_dispatch_policy_to_stride_B()); + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using TmaInternalElementA = cute::conditional_t>>; + using TmaInternalElementB = cute::conditional_t>>; + + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using ProblemShape = ConvProblemShape; + + CUTE_STATIC_ASSERT_V(evenly_divides(shape<0>(TileShape{}), tile_size<0>(TiledMma{})), "TileShape_M should be evenly divided by TiledMma_M"); + CUTE_STATIC_ASSERT_V(evenly_divides(shape<1>(TileShape{}), tile_size<1>(TiledMma{})) || (ConvOp == conv::Operator::kWgrad), "TileShape_N should be evenly divided by TiledMma_N"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + Step<_2,_1,_3>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + static constexpr bool is_im2col_A = detail::is_im2col_load::value; + static constexpr bool is_im2col_B = detail::is_im2col_load::value; + static constexpr bool is_strided_dgrad = ConvOp == conv::Operator::kDgrad && not is_im2col_A && not is_im2col_B; + + static constexpr int TileShapeMNKLRank = rank(TileShapeMNKL_{}); + // If rank > 3, TileL exists and it is GroupsPerTile. 
The kernel is grouped conv now. + static constexpr bool is_grouped_wgrad = ConvOp == conv::Operator::kWgrad && TileShapeMNKLRank > 3; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof(ElementA))) + + size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + ElementB const* ptr_B{nullptr}; + }; + +private: + + // Note that for fprop and non-strided dgrad kernel, the tma load mode is im2col for tensor A and tiled for + // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor + // B since operand A, B is swapped. + // For strided dgrad A and B are both tma tiled and not im2col + + template + static constexpr auto + get_tma_load_a_instance( + TensorA const& tensor_a, + ProblemShape const& problem_shape, + ClusterShapeVMNK const& cluster_shape_vmnk) { + + if constexpr (is_im2col_A) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + // gbasis strides for dgrad kernel need to be negated + cute::array stride_srt{}; + for (int i = 0; i < NumSpatialDimensions; ++i) { + stride_srt[i] = ConvOp == conv::Operator::kDgrad ? 
+ -problem_shape.dilation[NumSpatialDimensions-1-i] : + problem_shape.dilation[NumSpatialDimensions-1-i]; + } + + return make_im2col_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk, + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + shape(stride_srt)); + } + // TMA tiled mode for tensor A in wgrad and strided dgrad + else { + return make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk); + } + } + + template + static constexpr auto + get_tma_load_b_instance( + TensorB const& tensor_b, + ProblemShape const& problem_shape, + ClusterShapeVMNK const& cluster_shape_vmnk) { + + if constexpr (is_im2col_B) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + return make_im2col_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk, + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + cute::reverse(shape(problem_shape.dilation))); + } + else { + return make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk); + } + } + +public: + + // Performs im2col transformations on the input of type ConvProblemShape + static constexpr auto + get_problem_shape_MNKL(ProblemShape const& problem_shape) { + if constexpr (is_im2col_A || is_im2col_B) { + // transformation + im2col linearization + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + } + + // Device-side kernel params + // + // Arguments has the untransformed problem shape from the user. + // Params will have the transformed problem shape. + struct Params { + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); + + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + // Assumption: StrideA is congruent with Problem_MK + // Select TMA load type according to convolution operator. 
+ using TensorShapeA = cute::conditional_t; + + using TensorShapeB = cute::conditional_t; + + using TMA_A = decltype(get_tma_load_a_instance( + make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(TensorShapeA{}, StrideA{})), + ConvProblemShape{}, + ClusterLayout_VMNK{})); + + using TMA_B = decltype(get_tma_load_b_instance( + make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(TensorShapeB{}, StrideB{})), + ConvProblemShape{}, + ClusterLayout_VMNK{})); + + // Members + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + }; + + // + // Constructor + // + CUTLASS_DEVICE + CollectiveConv(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + // + // Methods + // + + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple + // tma desc creation depends on the original untransformed domain. + + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. + auto shape_B_orig = problem_shape.get_shape_B(); + + // Fill inferred cute strides from flat stride arrays + auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); + Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk); + auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk); + auto tma_load_a_fallback = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk_fallback); + auto tma_load_b_fallback = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk_fallback); + + static_assert(size(typename decltype(tma_load_a)::ThrID{}) == size(AtomThrShapeMNK{})); + static_assert(size(typename decltype(tma_load_b)::ThrID{}) == size(AtomThrShapeMNK{})); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& 
problem_shape, + Arguments const& args) { + // Activation and Filter channel mode extents much match + bool implementable = true; + // channel mode is major + { + const bool check = problem_shape.stride_A[NumTensorDimensions-1] == 1; +#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) + if (not check) { + const auto offending_stride = + problem_shape.stride_A[NumTensorDimensions-1]; + std::ostringstream os; + os << "CollectiveConv::can_implement: " + "problem_shape.stride_A[NumTensorDimensions-1 = " + << (NumTensorDimensions-1) << "] = " + << offending_stride << " != 1"; + CUTLASS_TRACE_HOST( os.str() ); + } +#endif + implementable &= check; + } + + { + const bool check = problem_shape.stride_B[NumTensorDimensions-1] == 1; +#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) + if (not check) { + const auto offending_stride = + problem_shape.stride_B[NumTensorDimensions-1]; + std::ostringstream os; + os << "CollectiveConv::can_implement: " + "problem_shape.stride_B[NumTensorDimensions-1 = " + << (NumTensorDimensions-1) << "] = " + << offending_stride << " != 1\n"; + CUTLASS_TRACE_HOST( os.str() ); + } +#endif + implementable &= check; + } + + { + const auto & traversal_stride = problem_shape.traversal_stride; + for (auto stride: traversal_stride) { + implementable &= (stride >= 1 && stride <= 8); + } + } + + if constexpr (ConvOp == conv::Operator::kDgrad && not is_strided_dgrad) { + const auto & traversal_stride = problem_shape.traversal_stride; + for (auto stride: traversal_stride) { + implementable &= (stride == 1); + } + } + + constexpr int tma_alignment_bits = 128; + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. + auto shape_B_orig = problem_shape.get_shape_B(); + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + { + const bool check = cutlass::detail::check_alignment(shape_A_orig, StrideA{}); + if (not check) { + CUTLASS_TRACE_HOST("A shape and/or strides have alignment issue."); + } + implementable &= check; + } + + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + { + const bool check = cutlass::detail::check_alignment(shape_B_orig, StrideB{}); + if (not check) { + CUTLASS_TRACE_HOST("B shape and/or strides have alignment issue."); + } + implementable &= check; + } + + if (not implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + if (is_im2col_A || is_im2col_B) { + // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1] + constexpr int32_t 
offset_limit = 1 << (16 / NumSpatialDimensions); + auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B; + for (int i = 0; i < problem_shape.RankS; ++i) { + // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array + implementable = implementable && (flt_data[i+1] * problem_shape.dilation[i] >= 0) + && (flt_data[i+1] * problem_shape.dilation[i] <= (offset_limit - 1)); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { + + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); + return false; + } + } + + // Conv kernels only support cross correlation mode currently. + { + implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + } + + // When groups > 1, it should be a Grouped Conv. + if (problem_shape.groups > 1) { + implementable &= TileShapeMNKLRank > 3; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Only Grouped Conv can support groups > 1.\n"); + return false; + } + } + + // Only support Grouped Wgrad currently. + if constexpr (TileShapeMNKLRank > 3) { + implementable &= ConvOp == conv::Operator::kWgrad; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv Only support Grouped Wgrad currently.\n"); + return false; + } + } + + // Grouped Wgrad channel check. 
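// Illustration with hypothetical sizes (only the checks below are normative): with input_K = input_C = 64
// and groups = 4, the channels per group are c = input_C / groups = 16, so a Tile_N of 64 requires
// GroupsPerTile = Tile_N / c = 4, i.e. size<3>(TileShapeMNKL_) == 4.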
+ if constexpr (is_grouped_wgrad) { + + int input_K = size<0>(problem_shape.get_shape_A()); + int input_C = size<0>(problem_shape.get_shape_B()); + + implementable &= input_K == input_C; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv's input K and input C do not match.\n"); + return false; + } + + int output_K = size<0>(problem_shape.get_shape_C()); + int output_C = size<1,0>(problem_shape.get_shape_C()); + + implementable &= input_K == output_K; + implementable &= input_C == output_C * problem_shape.groups; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's input and output K,C and groups do not match\n"); + return false; + } + + constexpr int Tile_N = size<1>(TileShape{}); + constexpr int GroupsPerTile = size<3>(TileShapeMNKL_{}); + + implementable &= Tile_N / GroupsPerTile == input_C / problem_shape.groups; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's Tile_N, GroupsPerTile and input_C, groups do not match.\n"); + return false; + } + } + + return true; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& mainloop_params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == mainloop_params.cluster_shape_fallback.x && cs.y == mainloop_params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mk, tBgB_nk, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mk(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _); + auto tensor_b_coord = get<1>(cta_coord_mnkl); + if constexpr (is_grouped_wgrad) { + // in grouped wgrad, tensor A = NZPQK, tensor B = NDHWC, tensor C = KTRSc, where C = G*c, c = channel_per_group = 8,16,32. + // CTA Tiling follows output tensor KTRSc. So cta_size_m = K/CTA_TILE_M. cta_size_n = T*R*S*ceil(c/CTA_TILE_N) = T*R*S*1 = T*R*S. + // tensor_a_coord = K_idx = cta_coord_m. 
+ // tensor_b_coord = TRS_idx * C/CTA_TILE_N + C_idx = cta_coord_n * get<1,0>(shape(tBgB_nk) + cta_coord_m, + // because K == C and CTA_TILE_M == CTA_TILE_N => C_idx = K_idx = cta_coord_m. + tensor_b_coord = get<0>(cta_coord_mnkl) + get<1>(cta_coord_mnkl) * get<1,0>(shape(tBgB_nk)); + } + Tensor tBgB = tBgB_nk(_, tensor_b_coord, _); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if constexpr (is_strided_dgrad) { + // construct gemm-k tile coord for gB + auto [conv_k, flt_coord, out_coord] = *k_tile_iter; + auto gemm_k_tile = prepend(flt_coord, conv_k); // (k,s,r,t) + + // gA doesn't have a gemm-k (k,s,r,t) iterator mode because it's not an im2col tensor + auto offset_kqpzn = append(prepend(out_coord, _0{}),_0{}); // (k,q,p,z,n) + auto tAgA_offset = make_tensor(tAgA.data() + offset_kqpzn, tAgA.layout()); // (TMA, k) + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA_offset(_,conv_k), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,gemm_k_tile) , tBsB(_,write_stage)); + } + } + else { + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mk - The tiled tma tensor for input A + /// gB_nk - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + auto K_A = conditional_return(get<0>(K), K); + Tensor mA_mk = observed_tma_load_a_->get_tma_tensor(make_shape(M, K_A)); + Tensor mB_nk = observed_tma_load_b_->get_tma_tensor(make_shape(N, K)); + + // Tile the tensors and defer the slice + Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k) + Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mk = cta_mma.partition_A(gA_mk); // (MMA, MMA_M, MMA_K, m, k) + Tensor tCgB_nk = cta_mma.partition_B(gB_nk); // (MMA, MMA_N, MMA_K, n, k) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mk)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nk, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nk)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mk, gB_nk, // for scheduler + tAgA_mk, tBgB_nk, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b); // multicast masks + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + /* This helps avoid early exit of ctas in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class 
FragmentA, class FragmentB + > + CUTLASS_DEVICE auto + mma(MainloopPipeline pipeline, + MainloopPipelineState mainloop_pipe_consumer_state, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + int k_tile_count) + { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available (phase bit flips from mainloop_pipe_consumer_state.phase() value) + pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + CUTLASS_DEVICE auto + mma_init(TensorStorage& shared_tensors) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + TiledMma tiled_mma; + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = tiled_mma.make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = tiled_mma.make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + +private: + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 4437ae15f0..504575ad57 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -326,7 +326,9 @@ class ConvUniversalAdapter else { CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; - if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) { + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90 + || ConvKernel::ArchTag::kMinComputeCapability == 100 + ) { if constexpr (is_static_1x1x1) { 
device_kernel<<>>(params); launch_result = Status::kSuccess; diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index d9e20f46b3..b4bf8a5382 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -83,6 +83,37 @@ struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { "Persistent schedules not support for conv yet."); }; + + +// SM100 tensor op kernel schedule +struct KernelImplicitTmaWarpSpecializedSm100 { }; + +// Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100 +// but for opting into 1 or 2 SM atoms +struct KernelImplicitTmaWarpSpecialized1SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { }; +struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { }; + +struct KernelStridedDgradTmaWs1SmSm100 { }; +struct KernelStridedDgradTmaWs2SmSm100 { }; + +// n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop +template< + conv::Operator ConvOp_, + int Stages_, + int NumSpatialDimensions_, + class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>> +> +struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { + static constexpr int Stages = Stages_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + static constexpr Operator ConvOp = ConvOp_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelImplicitTmaWarpSpecializedSm100; + + static_assert(NumSpatialDimensions >= 1); +}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv diff --git a/include/cutlass/conv/kernel/conv_universal.hpp b/include/cutlass/conv/kernel/conv_universal.hpp index c9bd4b9fcf..af804df30e 100644 --- a/include/cutlass/conv/kernel/conv_universal.hpp +++ b/include/cutlass/conv/kernel/conv_universal.hpp @@ -61,4 +61,5 @@ class ConvUniversal { //////////////////////////////////////////////////////////////////////////////// #include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp" +#include "cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/conv_universal_dispatch.hpp b/include/cutlass/conv/kernel/conv_universal_dispatch.hpp new file mode 100644 index 0000000000..8507a17188 --- /dev/null +++ b/include/cutlass/conv/kernel/conv_universal_dispatch.hpp @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/workspace.h" + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +enum class DispatchMode { + VoidC // Select between voidC and non-voidC kernel based on beta scaling +}; + +// Dispatch between two ConvUniversal kernels +template +class ConvUniversalDispatch; + +//////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class MainloopWithC_, class EpilogueWithC_, + class MainloopVoidC_, class EpilogueVoidC_, + class TileScheduler_ +> +class ConvUniversalDispatch< + DispatchMode::VoidC, + ConvUniversal, + ConvUniversal, + cute::void_t +> : public ConvUniversal { +private: + using KernelWithC = ConvUniversal; + using KernelVoidC = ConvUniversal; + using FusionArguments = cute::remove_cvref_t; + +public: + // Mainloop derived types + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + + // Epilogue derived types + static_assert(not cute::is_void_v); + static_assert( cute::is_void_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + static_assert(cute::is_same_v); + + // TileID scheduler + static_assert(cute::is_same_v); + + static constexpr int SharedStorageSize = cute::max(KernelWithC::SharedStorageSize, KernelVoidC::SharedStorageSize); + + static_assert(KernelWithC::MaxThreadsPerBlock == KernelVoidC::MaxThreadsPerBlock); + + static_assert(KernelWithC::MinBlocksPerMultiprocessor == KernelVoidC::MinBlocksPerMultiprocessor); + + using Arguments = typename KernelWithC::Arguments; + + struct Params { + typename KernelWithC::Params withC; + typename KernelVoidC::Params voidC; + + void const* ptr_C; + decltype(FusionArguments{}.beta) beta; + decltype(FusionArguments{}.beta_ptr) beta_ptr; + decltype(FusionArguments{}.dBeta) dBeta; + cutlass::KernelHardwareInfo hw_info{}; + }; + + static size_t + get_workspace_size(Arguments const& args) { + return KernelWithC::get_workspace_size(args); + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, CudaHostAdapter* 
cuda_adapter = nullptr) { + return KernelWithC::initialize_workspace(args, workspace, stream, cuda_adapter); + } + + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + return { + KernelWithC::to_underlying_arguments(args, workspace), + KernelVoidC::to_underlying_arguments(reinterpret_cast(args), workspace), + args.epilogue.ptr_C, + args.epilogue.thread.beta, + args.epilogue.thread.beta_ptr, + args.epilogue.thread.dBeta, + args.hw_info + }; + } + + static dim3 + get_grid_shape(Params const& params) { + return KernelWithC::get_grid_shape(params.withC); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + + bool run_voidC = false; + if (params.ptr_C == nullptr) { + run_voidC = true; + } + else if (params.beta_ptr == nullptr) { // Host scalar beta + run_voidC = params.beta == 0; + } + else if (get<0>(params.dBeta) == 0 && get<1>(params.dBeta) == 0) { // Device scalar beta + auto L = get<3>(append<4>(params.withC.problem_shape, _1{})); + if (get<2>(params.dBeta) == repeat_like(L, 0) || size(L) == 1) { // Non-batched + run_voidC = *params.beta_ptr == 0; + } + } + + if (run_voidC) { + return kernel_voidC(params.voidC, smem_buf); + } + else { + return KernelWithC::operator()(params.withC, smem_buf); + } + } + +private: + KernelVoidC kernel_voidC; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000000..90236e1fd9 --- /dev/null +++ b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -0,0 +1,911 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/arch/cluster_sm90.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/conv/detail.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/sm100_pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class ConvUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + + // Mainloop derived types + using ProblemShape = ProblemShape_; + using CollectiveMainloop = CollectiveMainloop_; + + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; + static constexpr bool is_grouped_wgrad = CollectiveMainloop::is_grouped_wgrad; + static constexpr bool IsComplex = false; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + // TileID scheduler + // CLC 
pipeline depth determines how many waves (stages-1) the scheduler can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = 2; + + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + + // Pipelines and pipeline states + static constexpr uint32_t AccumulatorPipelineStageCount = SchedulerPipelineStageCount; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = cutlass::PipelineDetail::PipelineCLCFetchAsyncPipelineState; + using CLCPipelineSharedStorage = cutlass::PipelineDetail::PipelineCLCFetchAsyncSharedStorage; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = cutlass::PipelineDetail::PipelineAsyncPipelineState; + using CLCThrottlePipelineSharedStorage = cutlass::PipelineDetail::PipelineAsyncSharedStorage; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = CLCPipelineSharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = CLCThrottlePipelineSharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + 
alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + using ProblemShapeMNKL = decltype(CollectiveMainloop::get_problem_shape_MNKL(ProblemShape{})); + ProblemShapeMNKL problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + }; + + // + // Methods + // + // Map user facing arguments to device facing params + CUTLASS_HOST + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + + auto problem_shape_mnkl = CollectiveMainloop::get_problem_shape_MNKL(args.problem_shape); + + auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace, args.hw_info); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_mnkl, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + problem_shape_mnkl, + mainloop_params, + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace), + args.hw_info + }; + } + + CUTLASS_HOST + static bool + can_implement(Arguments const& args) { + bool implementable = true; + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, 
args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + if constexpr (is_grouped_wgrad) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); + + implementable &= size<0>(cluster_shape) == 1 && size<0>(cluster_shape_fallback) == 1; + + if (!implementable) { + return false; + } + } + + return implementable; + } + + CUTLASS_HOST + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + auto linear_problem_shape_MNKL = cutlass::conv::detail::get_linearized_problem_shape_MNKL(args.problem_shape); + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, linear_problem_shape_MNKL, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + CUTLASS_HOST + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + auto linear_problem_shape_MNKL = cutlass::conv::detail::get_linearized_problem_shape_MNKL(args.problem_shape); + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace + ( + args.scheduler, workspace_ptr + workspace_offset, stream, linear_problem_shape_MNKL, + args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + + workspace_offset += TileScheduler::template get_workspace_size + ( + args.scheduler, linear_problem_shape_MNKL, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, + CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + CUTLASS_HOST + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + return TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + 
AtomThrShapeMNK{}, + cluster_shape + ,params.hw_info + ); + } + + CUTLASS_HOST + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + auto problem_shape_MNKL = append<4>(params.problem_shape, _1{}); + auto [M, N, K, L] = problem_shape_MNKL; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : WarpCategory::Epilogue; + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue) // epilogue + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == 
warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
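// Editor's note (illustrative sketch, not part of the CUTLASS sources): the pipeline
// parameters chosen here reduce to simple arrival-count arithmetic. The standalone,
// compile-time check below redoes that arithmetic for one assumed configuration
// (2x1x1 cluster, a 4-warp epilogue, C tensor loaded); the per-role warp counts mirror
// the constants declared earlier in this kernel, everything else is an example value.
#include <cstdio>

constexpr unsigned NumThreadsPerWarp      = 32;
constexpr unsigned NumSchedThreads        = NumThreadsPerWarp;      // 1 warp
constexpr unsigned NumMMAThreads          = NumThreadsPerWarp;      // 1 warp
constexpr unsigned NumMainloopLoadThreads = NumThreadsPerWarp;      // 1 warp
constexpr unsigned NumEpilogueLoadThreads = NumThreadsPerWarp;      // 1 warp
constexpr unsigned NumEpilogueThreads     = 4 * NumThreadsPerWarp;  // assumed epilogue width
constexpr unsigned cluster_size           = 2;                      // assumed 2x1x1 cluster
constexpr unsigned AtomThrShapeM          = 2;                      // assumed 2-SM MMA atom
constexpr bool     is_epi_load_needed     = true;                   // C is read

// CLC response barrier: the scheduler warp arrives once, and each CTA in the cluster
// contributes its mainloop-load, epilogue and MMA warps (plus the epilogue-load warp
// only when C is actually loaded), matching the consumer_arv_count expression above.
constexpr unsigned clc_consumer_arv_count =
    NumSchedThreads
  + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads)
  + (is_epi_load_needed ? cluster_size * NumEpilogueLoadThreads : 0u);
static_assert(clc_consumer_arv_count == 480);

// Accumulator pipeline: a single producer thread arrives, while the epilogue warps of
// both CTAs of a 2-SM MMA atom act as consumers.
constexpr unsigned accumulator_consumer_arv_count = AtomThrShapeM * NumEpilogueThreads;
static_assert(accumulator_consumer_arv_count == 256);

int main() {
  std::printf("CLC consumers: %u, accumulator consumers: %u\n",
              clc_consumer_arv_count, accumulator_consumer_arv_count);
  return 0;
}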
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + if (cluster_size > 1) { + cute::cluster_arrive_relaxed(); + } + else { + __syncthreads(); + } + + uint32_t tmem_stage_ptrs[AccumulatorPipelineStageCount]; + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); + accumulator_pipeline.init_masks(cluster_shape); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, problem_shape_MNKL, TileShape{}, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = 
scheduler.work_tile_to_cta_coord(work_tile_info); + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto accumulators = TiledMma::make_fragment_C(acc_shape); + + int TmemColumnsPerAccumulatorTile = cutlass::detail::find_tmem_tensor_col_offset(accumulators); + pipeline_init_wait(cluster_size); + + if (is_participant.sched) { + + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + else if (is_participant.main_load) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mk = get<0>(load_inputs); + bool requires_clc_query = true; + + do { + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
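// Editor's note (illustrative sketch): the scheduler branch above and the mainloop-load
// branch that follows cooperate through the CLC throttle: the loader commits one token
// per work tile it picks up, and the scheduler must consume a token before issuing the
// next CLC query, so it can never race more than the throttle depth ahead of the loads.
// A minimal host-side model of that bounded handshake (plain C++; the queue, depth and
// tile count are stand-ins, not CUTLASS pipeline types):
#include <cstdio>
#include <queue>

int main() {
  constexpr int throttle_depth = 2;   // assumed CLC throttle pipeline depth
  constexpr int num_tiles      = 6;
  std::queue<int> throttle;           // tokens produced by the loader, consumed by the scheduler

  int tiles_loaded    = 0;
  int tiles_scheduled = 0;
  while (tiles_scheduled < num_tiles) {
    // Loader side: may run ahead by at most `throttle_depth` outstanding tokens
    while (tiles_loaded < num_tiles && static_cast<int>(throttle.size()) < throttle_depth) {
      throttle.push(tiles_loaded++);   // "producer_acquire / producer_commit"
    }
    // Scheduler side: consume one token, then issue the next CLC query
    if (!throttle.empty()) {
      int token = throttle.front();
      throttle.pop();                  // "consumer_wait / consumer_release"
      std::printf("scheduler issues next CLC query after tile %d was picked up\n", token);
      ++tiles_scheduled;               // "advance_to_next_work / fetch_next_work"
    }
  }
  return 0;
}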
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, TileShape{}, shape<3>(gA_mk)); + auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); + auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + else if (is_participant.epi_load) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + + CUTLASS_PRAGMA_UNROLL + for (int acc_stage = 0; acc_stage < AccumulatorPipelineStageCount; 
acc_stage++) { + tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; + } + auto mma_inputs = collective_mainloop.mma_init(shared_storage.tensors.mainloop); + do { + auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + // Accumulator stage slice + int acc_stage = accumulator_pipe_producer_state.index(); + accumulators.data() = tmem_stage_ptrs[acc_stage]; + + if (is_mma_leader_cta) { + mainloop_pipe_consumer_state = collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + mma_inputs, + k_tile_count + ); + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + + } + else if (is_participant.epilogue) { + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + CUTLASS_PRAGMA_UNROLL + for (int acc_stage = 0; acc_stage < AccumulatorPipelineStageCount; acc_stage++) { + tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; + } + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Accumulator stage slice after making sure allocation has been performed + int acc_stage = accumulator_pipe_consumer_state.index(); + accumulators.data() = tmem_stage_ptrs[acc_stage]; + + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, 
acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulators, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + + do_tail_store = true; + } + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + } + +private: + + // Synchronization call. Blocks until barriers are initialized in shared memory. + CUTLASS_DEVICE + void + pipeline_init_wait(int cluster_size) { + if (cluster_size > 1) { + cute::cluster_wait(); + } + else { + __syncthreads(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index 577638ef65..046b3063a8 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -110,6 +110,48 @@ std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { return out << float(x); } + +inline +std::ostream & operator<<(std::ostream &out, float_e2m1_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, detail::float_e2m1_unpacksmem_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_e3m2_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_e2m3_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, detail::float_e3m2_unpacksmem_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, detail::float_e2m3_unpacksmem_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_ue8m0_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_ue4m3_t const &x) { + return out << float(x); +} + + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to enable formatted printing of CUTLASS scalar types to an ostream diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index b2240c51ce..2e62f84afd 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -330,6 +330,23 @@ struct CudaHostAdapter { void** kernel_params, int32_t kernel_index) const = 0; + + + /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. + /// This API is for preferred cluster launch; a preferred and a fallback cluster shapes are + /// considered for launch respectively. 
+ virtual Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const fallback_cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + + + #if defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) /// Create a tensor map descriptor object representing im2col memory region. diff --git a/include/cutlass/detail/cluster.hpp b/include/cutlass/detail/cluster.hpp new file mode 100644 index 0000000000..d35765adeb --- /dev/null +++ b/include/cutlass/detail/cluster.hpp @@ -0,0 +1,99 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + + + +#include "cute/container/tuple.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/trace.h" +#include "cute/layout.hpp" // cute::make_shape +#include "cutlass/trace.h" // CUTLASS_TRACE_HOST + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +// Returns either ClusterShape, if it is static, or a Shape> populated with the +// x and y dimensions of `dynamic_cluster_shape`. 
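// Editor's note (illustrative sketch): the helper documented above simply prefers the
// compile-time cluster shape when one is fully known, and otherwise takes the x and y
// extents reported at run time with z fixed to 1. A dependency-free model of that
// selection, using a plain struct in place of cute shapes and dim3 (all names here are
// stand-ins, not CUTLASS APIs):
#include <cstdio>

struct Shape3 { int x, y, z; };

template <bool IsStaticClusterShape>
Shape3 select_cluster_shape(Shape3 static_shape, Shape3 runtime_shape) {
  if constexpr (IsStaticClusterShape) {
    return static_shape;                           // static: trust the compile-time shape
  }
  else {
    return {runtime_shape.x, runtime_shape.y, 1};  // dynamic: take runtime x/y, z fixed to 1
  }
}

int main() {
  Shape3 static_shape{2, 1, 1};
  Shape3 runtime_shape{4, 2, 1};  // e.g. what the runtime cluster query reported
  Shape3 s = select_cluster_shape<true >(static_shape, runtime_shape);
  Shape3 d = select_cluster_shape<false>(static_shape, runtime_shape);
  std::printf("static: %dx%dx%d, dynamic: %dx%dx%d\n", s.x, s.y, s.z, d.x, d.y, d.z);
  return 0;
}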
+template +CUTLASS_HOST_DEVICE +static auto +select_cluster_shape(ClusterShape cluster_shape, dim3 dynamic_cluster_shape) { + return cute::conditional_return>( + make_shape(static_cast(dynamic_cluster_shape.x), static_cast(dynamic_cluster_shape.y), cute::Int<1>{}), + cluster_shape); +} + +template +CUTLASS_DEVICE +static auto +select_cluster_shape(ClusterShape cluster_shape) { + if constexpr (cute::is_static_v) { + return cluster_shape; + } + else { + dim3 dynamic_cluster_shape = cute::cluster_shape(); + return make_shape(static_cast(dynamic_cluster_shape.x), static_cast(dynamic_cluster_shape.y), cute::Int<1>{}); + } +} + +// Dynamic cluster shape can_implement rule +template +CUTLASS_HOST_DEVICE +bool +preferred_cluster_can_implement(dim3 cluster_shape, dim3 cluster_shape_fallback) { + bool implementable{true}; + + // Runtime cluster shape should satisfy MMA requirements + auto AtomThrShapeM = cute::size<0>(AtomThrShapeMNK{}); + implementable &= (cluster_shape.x > 0 && cluster_shape.y > 0 && cluster_shape.z > 0); + implementable &= (cluster_shape.x % AtomThrShapeM == 0); + + implementable &= (cluster_shape_fallback.x > 0 && cluster_shape_fallback.y > 0 && cluster_shape_fallback.z > 0); + implementable &= (cluster_shape_fallback.x % AtomThrShapeM == 0); + + // Only support pow2 runtime cluster shape for now + implementable &= ispow2(cluster_shape.x) && + ispow2(cluster_shape.y) && + ispow2(cluster_shape.z); + + implementable &= ispow2(cluster_shape_fallback.x) && + ispow2(cluster_shape_fallback.y) && + ispow2(cluster_shape_fallback.z); + + return implementable; +} + +} // namespace cutlass::detail + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp index d7a83d0474..840bae2e66 100644 --- a/include/cutlass/detail/collective.hpp +++ b/include/cutlass/detail/collective.hpp @@ -32,6 +32,7 @@ #include "cute/container/tuple.hpp" #include "cute/layout.hpp" // cute::size(shape) +#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::MXF4Format, cute::UMMA::MXF8F6F4Format ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { @@ -57,6 +58,114 @@ static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represen template using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype::type; + + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_runtime_f8f6f4() { + return (cute::is_same_v || + cute::is_same_v || + cute::is_same_v); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_f8f6f4_inputs() { + return ( + + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + + cute::is_same_v || + cute::is_same_v + + || cute::is_same_v || + cute::is_same_v || + cute::is_same_v + + ) && + ( + + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + + cute::is_same_v || + cute::is_same_v + + || cute::is_same_v || + cute::is_same_v || + cute::is_same_v + + ); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm100_mma_f8f6f4() { + return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 32) && is_sm10x_f8f6f4_inputs(); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_f8f6f4_element() { + return (cute::is_same_v + || cute::is_same_v + + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + + ); +} + + + +template +CUTLASS_HOST_DEVICE +static constexpr bool 
+is_sm10x_block_scale_mxf8f6f4_input() { + // ElementType must be F8, F6, or F4 + return ( cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_block_scale_mxf4nvf4_input() { + // ElementType must be F4 + return ( cute::is_same_v || + cute::is_same_v + ); +} + +template +struct sm10x_block_scale_runtime_input_t { + static constexpr bool IsMxF8F6F4MmaInput = is_sm10x_block_scale_mxf8f6f4_input(); + static constexpr bool IsMxF4NvF4MmaInput = is_sm10x_block_scale_mxf4nvf4_input(); + + using Type = cute::conditional_t + >; +}; + + } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index 6f7d35b02b..ad5b51910e 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -237,7 +237,7 @@ struct LayoutAwareConvertImpl< } }; -// Specialization for UINT4 -> FPF16 with [02461357] value order +// Specialization for UINT4 -> FP16 with [02461357] value order template <> struct LayoutAwareConvertImpl< cutlass::uint4b_t, @@ -804,14 +804,15 @@ struct MixedInputUtils { { auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; asm volatile( "{\n" - " and .b32 %0, %2, %4 ;\n" \ - " and .b32 %1, %3, %5 ;\n" \ + " lop3 .b32 %0, %2, %4, %5, %6;\n" \ + " xor .b32 %1, %3, %5; \n" \ "}\n" : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) - : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0x7F7F7F00), "n"(0x7F7F7F7F) - ); + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) + ); } } CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 79a1f97bc3..fc9dd15ded 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -36,6 +36,8 @@ #include "cute/swizzle_layout.hpp" // cute::detail::get_swizzle_portion #include "cute/util/type_traits.hpp" #include "cute/arch/copy_sm90_tma.hpp" +#include "cute/arch/copy_sm100_tma.hpp" + #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" #include "cutlass/numeric_types.h" @@ -306,6 +308,8 @@ constexpr bool is_tma_copy_engine() { || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v ) { return true; } @@ -337,6 +341,16 @@ get_alignment_count_from_gmem_tiled_copy() { else { // For TMA tiled copies, we know the alignment has to be 128 bits if constexpr (is_tma_copy_engine()) { + + if constexpr ( cute::is_same_v::type, cutlass::detail::float_e2m1_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::float_e3m2_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::float_e2m3_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::type_erased_dynamic_float4_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::type_erased_dynamic_float6_unpacksmem_t> || + cutlass::gemm::collective::detail::is_sm10x_f8f6f4_element() && cute::is_same_v::type, uint8_t>) { + return 128; + } + // For sparse MMA, alignment in logical elements is increased by sparsity factor if constexpr (cute::is_sparse_v) { return 128 / sizeof_bits::value * 
ElementMma::sparsity; @@ -353,9 +367,17 @@ get_alignment_count_from_gmem_tiled_copy() { // Return alignment bit requirements for the GEMM inputs. template < class ElementType + , bool IsF8F6F4SubBytes=false > constexpr int get_input_alignment_bits() { + + if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 4) { + return 64 * 8; + } + else if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 6) { + return 96 * 8; + } return 128; } @@ -363,6 +385,12 @@ get_input_alignment_bits() { template constexpr int get_output_alignment_bits() { + + if constexpr (sizeof_bits::value == 6) { + // U6 format : The inner tensor size dimension must be a multiple of 96B. + return 96 * 8; + } + return 128; } diff --git a/include/cutlass/detail/mma.hpp b/include/cutlass/detail/mma.hpp index 0f2d0e1bd3..b4cbd3864a 100644 --- a/include/cutlass/detail/mma.hpp +++ b/include/cutlass/detail/mma.hpp @@ -47,17 +47,33 @@ template struct IsSparseTensorOp> : cute::true_type { }; + +template +struct IsBlockScaledTensorOp : cute::false_type { }; + +// TiledMma for blockScaled must have FrgTypeSFA +template +struct IsBlockScaledTensorOp> + : cute::true_type { }; + + // The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. template struct get_operator_class { static constexpr bool is_sparse_op = IsSparseTensorOp::value; + static constexpr bool is_block_scaled_op = IsBlockScaledTensorOp::value; + // All tensorop operations have atom shape's M >= 8 static constexpr bool is_tensor_op = cute::size<0>(typename TiledMma::AtomShape_MNK{}) >= 8; using type = cute::conditional_t< is_tensor_op, cute::conditional_t< is_sparse_op, cutlass::arch::OpClassSparseTensorOp, + cute::conditional_t< + is_block_scaled_op, + cutlass::arch::OpClassBlockScaledTensorOp, cutlass::arch::OpClassTensorOp + > >, cutlass::arch::OpClassSimt >; diff --git a/include/cutlass/detail/sm100_blockscaled_layout.hpp b/include/cutlass/detail/sm100_blockscaled_layout.hpp new file mode 100644 index 0000000000..cba49d6455 --- /dev/null +++ b/include/cutlass/detail/sm100_blockscaled_layout.hpp @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Blocked Scale configs specific for SM100 BlockScaled MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm100BlockScaledBasicChunk { + + using Blk_MN = _128; + using Blk_SF = _4; + + using SfKMajorAtom = Layout< Shape< Shape<_32,_4>, Shape, _4>>, + Stride, Stride< _0, _1>>>; + using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_32,_4>>, + Stride, Stride<_16,_4>>>; + using SfAtom = cute::conditional_t; +}; + +template +struct Sm100BlockScaledConfig { + // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. + // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm100BlkScaledChunk = Sm100BlockScaledBasicChunk; + using Blk_MN = typename Sm100BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm100BlkScaledChunk::Blk_SF; + using SfAtom = typename Sm100BlkScaledChunk::SfAtom; + + using LayoutSF = decltype(blocked_product(SfAtom{}, make_layout( make_shape(int32_t(0), int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}, int32_t(0))))); + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSF{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template < class ProblemShape, class LayoutSFA = LayoutSF> + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. 
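As a standalone sketch (not part of this patch, and assuming K is divisible by SFVecSize) of the padding rule that tile_atom_to_shape_SFA encodes above: rows are padded to Blk_MN = 128 and the per-row scale-factor count along K to a multiple of Blk_SF = 4, so an M of 192 is stored at 256-row granularity. tile_atom_to_shape_SFB, defined next, applies the same rule with N in place of M. The helper names below are hypothetical.

#include <cstdio>

constexpr long long round_up(long long x, long long to) { return ((x + to - 1) / to) * to; }

// Number of scale-factor entries implied by tiling the SF atom to (M, K, L).
constexpr long long sfa_extent(long long M, long long K, long long L, int sf_vec_size) {
  long long sf_k  = round_up(K / sf_vec_size, 4);  // Blk_SF = 4 scale factors per block along K
  long long m_pad = round_up(M, 128);              // Blk_MN = 128 rows per block
  return m_pad * sf_k * L;
}

int main() {
  // NVF4-style vector size of 16: one scale factor per 16 elements of K.
  std::printf("%lld\n", sfa_extent(/*M=*/192, /*K=*/256, /*L=*/1, /*SFVecSize=*/16)); // prints 4096 (256 rows * 16 SFs per row)
}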
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFA(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + constexpr int MMA_NSF = TiledMma::K / SFVecSize; + // Basic storage block for new Scaling Factor Layouts + using mnBasicBlockShape = Shape<_32,_4>; + using mnBasicBlockStride = Stride<_16,_4>; + using kBasicBlockShape = Shape, Int>; + using kBasicBlockStride = Stride<_0, _1>; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size + using Blk_MN = typename Sm100BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm100BlkScaledChunk::Blk_SF; + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + + using TL_VMNK = typename TiledMma::ThrLayoutVMNK; + constexpr TL_VMNK tl_vmnk{}; + constexpr int MMA_M = cute::size<0>(TileShape_MNK{}) / cute::size<0>(tl_vmnk); + using mma_SFA_shape = decltype( make_shape( prepend(Int{}/Blk_MN{}, mnBasicBlockShape{}), kBasicBlockShape{})); + using mma_SFA_stride = decltype(make_stride( prepend( Blk_Elems{}, mnBasicBlockStride{}), kBasicBlockStride{})); + using sSFA_shape = decltype( make_shape( mma_SFA_shape{}, _1{}, make_shape( Blk_SF{}/Int{}, Int(TileShape_MNK{}) / SFVecSize / Blk_SF{}>{}))); + using sSFA_stride = decltype(make_stride(mma_SFA_stride{}, _0{}, make_stride( Int{}, Int{}))); + using SmemLayoutAtomSFA = decltype(make_layout(sSFA_shape{}, sSFA_stride{})); + return SmemLayoutAtomSFA{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFB(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + constexpr int MMA_NSF = TiledMma::K / SFVecSize; + // Basic storage block for new Scaling Factor Layouts + using mnBasicBlockShape = Shape<_32,_4>; + using mnBasicBlockStride = Stride<_16,_4>; + using kBasicBlockShape = Shape, Int>; + using kBasicBlockStride = Stride<_0, _1>; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 
32bits corresponds to the TMEM word size + using Blk_MN = typename Sm100BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm100BlkScaledChunk::Blk_SF; + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + + using TL_VMNK = typename TiledMma::ThrLayoutVMNK; + constexpr TL_VMNK tl_vmnk{}; + constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); + // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. + // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. + constexpr int MMA_N_SFB = cutlass::ceil_div(MMA_N, Blk_MN{}) * Blk_MN{}; + using mma_SFB_shape = decltype(make_shape( prepend( Int{}/Blk_MN{}, mnBasicBlockShape{}), kBasicBlockShape{})); + using mma_SFB_stride = decltype(make_stride(prepend( Blk_Elems{}, mnBasicBlockStride{}), kBasicBlockStride{})); + using sSFB_shape = decltype( make_shape( mma_SFB_shape{}, _1{}, make_shape( Blk_SF{}/Int{}, Int(TileShape_MNK{}) / SFVecSize / Blk_SF{}>{}))); + using sSFB_stride = decltype(make_stride(mma_SFB_stride{}, _0{}, make_stride( Int{}, Int{}))); + using SmemLayoutAtomSFB = decltype(make_layout(sSFB_shape{}, sSFB_stride{})); + return SmemLayoutAtomSFB{}; + } +}; + + +template +struct Sm100BlockScaledOutputConfig { + // We are creating the SFD tensors' layouts in the collective. + // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm100BlkScaledChunk = cutlass::detail::Sm100BlockScaledBasicChunk; + using Blk_MN = typename Sm100BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm100BlkScaledChunk::Blk_SF; + using SfAtom = typename Sm100BlkScaledChunk::SfAtom; + + using LayoutKMajorSF = decltype(blocked_product(SfAtom{}, make_layout(make_shape (int32_t(0), int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}, int32_t(0))))); + + static_assert(major == UMMA::Major::K, "Only K-major scalefactor output is supported"); + using LayoutSF = LayoutKMajorSF; + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFD() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFC. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFD(ProblemShape problem_shape, LayoutSFD layout_sfc = LayoutSFD{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_2,_1,_3>{}); + } +}; + +//// Describe the Scalefactor Tensor without VectorSize +struct Sm100BlockScaledTensorConfig { + // k-major order + // The blockscaled tensor does not need to know vectorsize + using Blk_M = _128; + using Blk_N = _4; + using SfAtom = Layout< Shape< Shape<_32,_4>, Shape<_4>>, + Stride, Stride<_1>>>; + + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape(ProblemShape problem_shape) { + auto problem_shape_MNL = append<3>(problem_shape, 1); + auto [M, N, L] = problem_shape_MNL; + return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_2,_1,_3>{}); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/detail/sm100_tmem_helper.hpp b/include/cutlass/detail/sm100_tmem_helper.hpp new file mode 100644 index 0000000000..f12bac12dc --- /dev/null +++ b/include/cutlass/detail/sm100_tmem_helper.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief TMEM Accumulator Helpers for SM100 +*/ + +#pragma once + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + + +namespace cutlass::detail{ +constexpr uint32_t TmemColMask = 0x0000'FFFF; + +template +CUTE_HOST_DEVICE +static constexpr auto find_tmem_tensor_col_offset(TmemTensor tensor) { + using namespace cute; + return cosize(recast(tensor).layout()) & TmemColMask; +} + +template +CUTE_HOST_DEVICE +static constexpr auto make_sm100_accumulator(TiledMma tiled_mma, AccumulatorShape acc_shape, EpilogueTile epilogue_tile) { + using namespace cute; + static_assert(rank(acc_shape) == 3 || (rank(acc_shape) == 4 && IsOverlappingAccum == false), + "Expect a rank >= 3 accumulator shape compatible with an SM100 tiled mma, Overlapping accumulators is only available for non-complex kernels"); + if constexpr (IsOverlappingAccum) { + Tensor accumulators_tmp = TiledMma::make_fragment_C(append(acc_shape, Int<2>{})); + return make_tensor( + accumulators_tmp.data(), + shape(accumulators_tmp), + replace<3>( + stride(accumulators_tmp), + Int<(256 - size<1>(EpilogueTile{})) * stride<0, 1>(accumulators_tmp.layout())>{})); + } else { + return TiledMma::make_fragment_C(append( + acc_shape, + Int{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) + } +} +} // namespace cutlass::detail diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl new file mode 100644 index 0000000000..882a6e2f41 --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -0,0 +1,1052 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +// + +// + +#include "cute/layout.hpp" // cute::Shape +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cutlass/arch/mma.h" // cutlass::arch::OpClassTensorOp, cutlass::OpClassSparseTensorOp +#include "cute/atom/copy_traits_sm100.hpp" +#include "cute/atom/mma_traits_sm100.hpp" +#include "cute/util/type_traits.hpp" // cute::is_same_v + +#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false +#include "cutlass/detail/layout.hpp" +#include "cutlass/numeric_size.h" // cutlass::bytes_to_bits +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" // detail::is_sfd_epilogue_v +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +sm100_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // Get the max contiguous tile usable by TMA + [[maybe_unused]] auto tma_tile = cute::transform(EpilogueTile_MN{}, + [](auto const& epi_tile) { + // assumes get<0>(epi_tile) is coalesced and unit stride + return 
size<0>(coalesce(right_inverse(make_layout(get<0>(epi_tile))))); + }); + + // ColMajor C/D (M-major) + if constexpr (cutlass::detail::is_major<0>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::sm100_smem_selector< + UMMA::Major::MN, Element, decltype(get<0>(tma_tile)), decltype(get<1>(tma_tile)) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (cutlass::detail::is_major<1>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::sm100_smem_selector< + UMMA::Major::K , Element, decltype(get<0>(tma_tile)), decltype(get<1>(tma_tile)) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. +template < + class OpClass, + class CtaTileShape_MNK, + class EpilogueTileType, + class TmemWarpShape_MN, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class FusionOp +> +constexpr auto +sm100_compute_tile_shape_or_override() { + using namespace cute; + + if constexpr (cute::is_same_v && + cute::is_same_v && + size<1>(CtaTileShape_MNK{}) == 256) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int DpFull = 32; + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + // Note: + // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. + // This is a general workable epi_tile_N which does not promise best perf. + return make_tile(Int{}, Int<128>{}); + } + else if constexpr (cute::is_same_v) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int CtaN = size<1>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int WarpN = size<1>(TmemWarpShape_MN{}); + constexpr bool DisableSource = is_void_v; + constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); + + constexpr int DpFull = 32; // tmem datapaths in 1 subpartition + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf + // Epilogues w/o residual load are less sensitive to smem allocation + // Target a fixed amount of compute per epilogue iteration + if (DisableSource) { + if (MaxBits == 4) { + // Make epilogue tile larger to reduce the epilogue iterations. + // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + constexpr int ComputeElts = 8192; + return ComputeElts / M; + } + constexpr int ComputeElts = 4096; + return ComputeElts / M; + } + // Epilogues w/ residual load are more sensitive to smem allocation + // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + else { + if (MaxBits == 32) { + return (CtaM > 64 && CtaN <= 128) ? 16 : 32; + } + // Per-column scaling is high register pressure, reduce tile to prevent spills + else if (FusionOp::IsPerColScaleSupported) { + return 32; + } + else if (MaxBits == 16) { + return (CtaN <= 128) ? 32 : 64; + } + else { + return 64; + } + } + }(); + constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 
128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); + static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); + + // stride by tmem warp layout and return a by-mode tiler + auto tile_m = Layout>{}; + auto tile_n = Layout,Int< WarpN>>, + Stride,Int>>{}; + + return make_tile(tile_m, coalesce(tile_n)); + } + else if constexpr (cute::is_tuple::value) { + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(TmemWarpShape_MN{} == Shape<_2,_2>{} && (M == 32 || M == 64) || + TmemWarpShape_MN{} == Shape<_4,_1>{} && (M == 64 || M == 128), "Unsupported tile shape"); + static_assert(N % 8 == 0, "Unsupported tile shape"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} + +template +static constexpr bool IsPtrArrayDispatchPolicy = + cute::is_same_v || + cute::is_same_v; + + +template < + class CtaTileShape_MNK, + class EpilogueTile_MN, + class ElementC, + class ElementD, + class Schedule +> +constexpr auto +sm100_get_tma_dispatch_policy() { + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{}))); + constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup; + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = sizeof_bits_v > 8; + constexpr bool DelayTmaStore = false; + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (detail::IsPtrArrayDispatchPolicy) { + return Sm100PtrArrayTmaWarpSpecialized{}; + } + else + { + return Sm100TmaWarpSpecialized{}; + } +} + +/* + * Returns the TMEM_LOAD copy op to be used for the epilogue + * Returned TMEM_LOAD op is such that the thread-value ownership matches the widest available + * smem storage vectorization, subject to the constraints of data types and gmem layout + * Selected op also maximizes the TMEM_LOAD shape in order to minimize TMEM_LOADs issued, + * subject to the constraint of the provided per-warp tmem subpartition shape +**/ +template +constexpr auto +sm100_get_tmem_load_op() { + using namespace cute; + + // Number of datapaths (dp) available in this warp's tmem subpartition. 
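A quick numeric sketch (not CUTLASS code) of the stage-count arithmetic in sm100_get_tma_dispatch_policy above, for an assumed 128x256 CTA tile, 128x64 epilogue subtile, and 16-bit ElementC, with NumThreadsPerWarpGroup = 128:

#include <algorithm>
#include <cstdio>

int main() {
  constexpr int cta_m = 128, cta_n = 256;   // assumed CTA tile
  constexpr int epi_m = 128, epi_n = 64;    // assumed epilogue subtile
  constexpr int c_bits = 16;                // assumed sizeof_bits of ElementC

  constexpr int  epi_tiles     = (cta_m / epi_m) * (cta_n / epi_n);  // epilogue subtiles per CTA tile
  constexpr int  fragment_size = (epi_m * epi_n) / 128;              // values per thread in a warp group
  constexpr bool reuse_smem    = c_bits > 8;                         // 8-bit C residuals do not reuse smem
  constexpr int  stages_d      = std::min(epi_tiles, 2);
  constexpr int  stages_c      = reuse_smem ? std::max(std::min(epi_tiles, 4), stages_d + 1)
                                            : std::min(epi_tiles, 4);
  std::printf("EpiTiles=%d FragmentSize=%d StagesC=%d StagesD=%d\n",
              epi_tiles, fragment_size, stages_c, stages_d);         // 4 64 4 2
}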
+ // If only 16dp are available then we must use 16dp TMEM_LOAD variants + // otherwise we prefer 32dp variants as those have higher throughput + + // For those fused patterns which have RowReduction or RowBroadcast + // 16dp tmem load op can effectively reduce the usage of registers & shuffle instrs + // Compared to TMEM_LOAD throughput, it's more critical + constexpr int num_dp = size<0>(TmemShape_MN{}); + static_assert(num_dp == 16 || num_dp == 32, "Unsupported tmem datapath count"); + + // Number of columns in this tmem subpartition, in bits + // Used to select the widest cross variant TMEM_LOAD available + constexpr int num_col_bits = size<1>(TmemShape_MN{}) * sizeof_bits_v; + + // Layout information, determines max available smem store vectorization + // For M-major layouts we tend to target stmatrix_t (UMMA stores tmem accumulator in N-Major) + constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{}); + constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{}); + static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); + + // dispatch on data types as this determines the correspondence + // between TMEM_LOAD thread-bit ownership patterns and logical values + if constexpr (sizeof_bits_v == 32 && sizeof_bits_v == 32) { + if constexpr (num_dp == 16) { + if constexpr (is_m_major) { + return TMEM::op_repeater(); // 32b stores to smem + } + else { + return TMEM::op_repeater(); // stmatrix_n + // return TMEM::op_repeater(); // 64b stores to smem + // return TMEM::op_repeater(); // 128b stores to smem + } + } + else { + return TMEM::op_repeater(); // 32b or 128b stores to smem + } + } + + else if constexpr (sizeof_bits_v == 32 && sizeof_bits_v == 16) { + if constexpr (num_dp == 16) { + if constexpr (is_m_major) { + return TMEM::op_repeater(); // stmatrix_t + } + else { + return TMEM::op_repeater(); // stmatrix_n + // return TMEM::op_repeater(); // 128b stores to smem + } + } + else { + if constexpr (is_m_major) { + return TMEM::op_repeater(); // stmatrix_t + } + else { + return TMEM::op_repeater(); // 128b stores to smem + } + } + } + + // For int8 kernels where accumulation is 32b but result store may be back to int8 + else if constexpr (sizeof_bits_v == 32 && sizeof_bits_v == 8) { + if constexpr (num_dp == 16) { + if constexpr (is_m_major) { + return TMEM::op_repeater(); // stmatrix_t m16n8 + } + else { + // return TMEM::op_repeater(); // 16b stores to smem + return TMEM::op_repeater(); // 128b stores to smem + } + } + else { + // To use the HW instruction to find amax along the row/column of acc, the TMEM_LOAD pattern needs to be 32dp32bit. 
+ return TMEM::op_repeater(); // 128b stores to smem + } + } + + // For 16b accumulation we use pack16b TMEM_LOAD variants as UMMA stores these values sparsely in tmem + else if constexpr (sizeof_bits_v == 16 && sizeof_bits_v == 16) { + if constexpr (num_dp == 16) { + if constexpr (is_m_major) { + return TMEM::op_repeater(); // stmatrix_t + } + else { + return TMEM::op_repeater(); // stmatrix_n + // return TMEM::op_repeater(); // 128b stores to smem + } + } + else { + if constexpr (is_m_major) { + return TMEM::op_repeater(); // stmatrix_t + } + else { + return TMEM::op_repeater(); // 128b stores to smem + } + } + } + // For complex TF32 kernels + else if constexpr (sizeof_bits_v == 64 && sizeof_bits_v == 64) { + if constexpr (num_dp == 16) { + return TMEM::op_repeater(); + } + else { + return TMEM::op_repeater(); + } + } + // For narrow precision output + else if constexpr (sizeof_bits_v == 32 && sizeof_bits_v == 6) { + static_assert(num_dp == 32); + return TMEM::op_repeater(); + } + else if constexpr (sizeof_bits_v == 32 && sizeof_bits_v == 4) { + static_assert(num_dp == 32); + return TMEM::op_repeater(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported data types"); + } +} + +// Selects the largest vectorized smem store atom available +// subject to constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership +template +constexpr auto +sm100_get_smem_store_op() { + using namespace cute; + + [[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{}); + [[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{}); + static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); + + // Check for TMEM_LOAD layouts that match the thread-value ownership pattern of stmatrix + // TODO: check copy vectorization instead! 
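For orientation, a back-of-the-envelope sketch (not CUTLASS code) of the two quantities that drive the TMEM_LOAD selection above: num_dp comes from the M extent of the per-warp tmem subpartition and num_col_bits from its N extent times the accumulator width. The subtile and warp-layout values below are assumptions.

#include <cstdio>

int main() {
  constexpr int epi_m = 128, epi_n = 64;   // assumed epilogue subtile
  constexpr int warp_m = 4, warp_n = 1;    // assumed TmemWarpShape_MN = (4,1)
  constexpr int acc_bits = 32;             // fp32 accumulator

  constexpr int num_dp       = epi_m / warp_m;               // datapaths per warp subpartition
  constexpr int num_col_bits = (epi_n / warp_n) * acc_bits;  // tmem column extent in bits
  static_assert(num_dp == 16 || num_dp == 32, "selection above handles 16dp/32dp only");

  std::printf("num_dp=%d num_col_bits=%d\n", num_dp, num_col_bits);  // 32 2048
}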
+ constexpr bool use_stmatrix_m8n8_4x = + (sizeof_bits_v == 32 && sizeof_bits_v == 32 && is_n_major && + ( cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v ) ) || + (sizeof_bits_v == 32 && sizeof_bits_v == 16 && + ( cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v ) ) || + (sizeof_bits_v == 16 && sizeof_bits_v == 16 && + ( cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v )); + [[maybe_unused]] constexpr bool use_stmatrix_m16n8_4x = + (sizeof_bits_v == 32 && sizeof_bits_v == 8 && is_m_major && + ( cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v ) ); + + // 1x TMEM_LOAD doesn't have enough values to use largest stmatrix variants + [[maybe_unused]] constexpr bool use_stmatrix_m8n8_2x = + (sizeof_bits_v == 32 && sizeof_bits_v == 32 && is_n_major && + cute::is_same_v ) || + (sizeof_bits_v == 32 && sizeof_bits_v == 16 && + cute::is_same_v ) || + (sizeof_bits_v == 16 && sizeof_bits_v == 16 && + cute::is_same_v ); + [[maybe_unused]] constexpr bool use_stmatrix_m16n8_2x = + (sizeof_bits_v == 32 && sizeof_bits_v == 8 && is_m_major && + cute::is_same_v ); + [[maybe_unused]] constexpr bool use_stmatrix_m16n8_1x = + (sizeof_bits_v == 32 && sizeof_bits_v == 8 && is_m_major && + cute::is_same_v ); + + if constexpr (use_stmatrix_m8n8_4x) { + if constexpr (is_n_major) { + return SM90_U32x4_STSM_N{}; + } + else if constexpr (is_m_major) { + return SM90_U16x8_STSM_T{}; + } + } + else if constexpr (use_stmatrix_m8n8_2x) { + if constexpr (is_n_major) { + return SM90_U32x2_STSM_N{}; + } + else if constexpr (is_m_major) { + return SM90_U16x4_STSM_T{}; + } + } + else if constexpr (use_stmatrix_m16n8_4x) { + return SM100_U8x16_STSM_T{}; + } + else if constexpr (use_stmatrix_m16n8_2x) { + return SM100_U8x8_STSM_T{}; + } + else if constexpr (use_stmatrix_m16n8_1x) { + return SM100_U8x4_STSM_T{}; + } + else { + // auto-vectorizing store + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } +} + +template +constexpr auto +sm100_get_register_transform_op() { + using namespace cute; + + [[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{}); + [[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{}); + static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); + + if constexpr (sizeof_bits_v == 4 && is_m_major) { + return SM50_Shuffle_U32_2x2Trans_XOR1{}; + } + else { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } +} + +// Selects the largest vectorized smem load atom available +// subject to constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership +template +constexpr auto +sm100_get_smem_load_op() { + using namespace cute; + + // Reuse the logic from smem store selector + using SmemStoreOp = decltype(sm100_get_smem_store_op< + GmemStrideTypeC, ElementC, ElementAccumulator, AccLoadOp>()); + + if constexpr (cute::is_same_v) { + return SM75_U32x4_LDSM_N{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U16x8_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U32x2_LDSM_N{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U16x4_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM100_U8x16_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM100_U8x8_LDSM_T{}; + } + else { + // auto-vectorizing load + return 
AutoVectorizingCopyWithAssumedAlignment{}; + } +} + +template +constexpr auto +sm100_get_gmem_load_op() { + + if constexpr (detail::is_im2col_mode) { + return SM90_TMA_LOAD_IM2COL{}; + } + else { + + return SM90_TMA_LOAD{}; + } +} + +template +constexpr auto +sm100_get_gmem_store_op() { + + if constexpr (detail::is_im2col_mode) { + return SM90_TMA_STORE_IM2COL{}; + } + else { + + return SM90_TMA_STORE{}; + } +} + +// aux fusion callbacks builder for sm100 tma epilogue +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class CtaTileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp +> +struct CallbacksBuilder< + Sm100TmaWarpSpecialized, + FusionOp, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm100_get_smem_store_op< + GmemStrideTypeAux, typename FusionOp::ElementAux, ElementAccumulator, AccLoadOp>()); + using CopyOpS2R = decltype(detail::sm100_get_smem_load_op< + GmemStrideTypeAux, typename FusionOp::ElementAux, ElementAccumulator, AccLoadOp>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm100TmaWarpSpecialized, + FusionOp, CtaTileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class CtaTileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp +> +struct CallbacksBuilder< + Sm100TmaWarpSpecialized, + FusionOp, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && sizeof_bits_v == 1> +> { + using Callbacks = fusion::FusionCallbacks< + Sm100TmaWarpSpecialized, + FusionOp, CtaTileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem + >; +}; + +// aux fusion callbacks builder for sm100 direct store epilogue +template < + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm100NoSmemWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported)> // only one aux tensor +> { + using Callbacks = fusion::FusionCallbacks< + Sm100NoSmemWarpSpecialized, FusionOp, TileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux tensor doesn't use tma + >; +}; + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the fusion operation performed and the dispatch policy to use. 
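As a usage-level sketch of how the builder machinery defined just below is typically reached through the CollectiveBuilder entry point (parameter order as in the specializations later in this file): the concrete element types, layouts, tile and cluster shapes are illustrative assumptions, and the fusion-operation parameter is left at its default.

#include "cutlass/epilogue/collective/collective_builder.hpp"

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    cute::Shape<cute::_128, cute::_256, cute::_64>,      // CTA tile shape (M,N,K), assumed
    cute::Shape<cute::_2,   cute::_1,   cute::_1>,       // cluster shape, assumed
    cutlass::epilogue::collective::EpilogueTileAuto,     // let the builder pick the epilogue subtile
    float, float,                                        // ElementAccumulator, ElementCompute
    cutlass::half_t, cutlass::layout::RowMajor, 8,       // C: element, layout tag, alignment in elements
    cutlass::half_t, cutlass::layout::RowMajor, 8,       // D: element, layout tag, alignment in elements
    cutlass::epilogue::collective::EpilogueScheduleAuto  // defer schedule selection to the builder
  >::CollectiveOp;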
+template < + class OpClass, + class CtaTileShape_MNK, + class EpilogueTileType, + class TmemWarpShape_MN, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOpOrCallbacks +> +struct Sm100TmaBuilderImpl { + // Passing void C disables source load + smem allocation + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using CopyOpS2G = decltype(detail::sm100_get_gmem_store_op()); + using CopyOpG2S = decltype(detail::sm100_get_gmem_load_op()); + + using FusionOp = conditional_t, + FusionOpOrCallbacks, epilogue::fusion::FusionOperation>; + + using EpilogueTile_MN = decltype(detail::sm100_compute_tile_shape_or_override< + OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN, + ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp>()); + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{}))); + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + + using InternalSmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using InternalSmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + using DispatchPolicy = decltype(detail::sm100_get_tma_dispatch_policy< + CtaTileShape_MNK, EpilogueTile_MN, ElementC_, ElementD, Schedule>()); + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination + using FusionCallbacks = + typename CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp + >::Callbacks; + + using CollectiveOp = + cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + FusionCallbacks, + AccLoadOp, + CopyOpG2S, + decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm100_get_smem_load_op()), + CopyOpS2G, + decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm100_get_smem_store_op()), + decltype(detail::sm100_get_register_transform_op()) + >; +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +// No smem builder +template < + class CtaTileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + CtaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOpOrCallbacks, + cute::enable_if_t || + cute::is_same_v >> { + + static_assert(cute::is_same_v, "Epilogue subtiling requires smem"); + static_assert(cute::sizeof_bits_v != 4 and cute::sizeof_bits_v != 6, "Output element requires smem"); + + static constexpr bool DisableSource = cute::is_void_v; + using ElementC = cute::conditional_t; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t; + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using FusionOp = conditional_t, + FusionOpOrCallbacks, epilogue::fusion::FusionOperation>; + + // use a 4x2 division to select tmem load shape in order to maintain compatability with both (4,1) and (2,2) layouts + using EpilogueTile = decltype(take<0,2>(CtaTileShape_MNK{})); + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_2>{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + + using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized; + + using AlignmentCType = Int; + using AlignmentDType = Int; + + static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + static constexpr thread::ScaleType::Kind ScaleType = DisableSource ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + using FusionCallbacks = cute::conditional_t< + IsDefaultFusionOp::value, + // Legacy codepath using thread::LinearCombination, do not expect this to be stable + thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC> + , + typename detail::CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + CtaTileShape_MNK, + EpilogueTile, + ElementAccumulator, + AccLoadOp + >::Callbacks + >; + + using CollectiveOp = cute::conditional_t< + cute::is_same_v, + cutlass::epilogue::collective::CollectiveEpilogue< + cutlass::epilogue::Sm100NoSmemWarpSpecialized, + EpilogueTile, + ElementC_, + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + FusionCallbacks, + AccLoadOp, + AlignmentCType, + AlignmentDType + >, + cutlass::epilogue::collective::CollectiveEpilogue< + cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized, + EpilogueTile, + ElementC_, + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + FusionCallbacks, + AccLoadOp + > + >; +}; + +// No smem builder for OpClassBlockScaledTensorOp +template < + class CtaTileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassBlockScaledTensorOp, + CtaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp, + cute::enable_if_t || + cute::is_same_v >> { + + static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); + + static constexpr bool DisableSource = cute::is_void_v; + using ElementC = cute::conditional_t; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t; + static constexpr thread::ScaleType::Kind ScaleType = DisableSource ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + static_assert(cute::is_tuple::value || cute::is_same_v); + using EpilogueTile = cute::conditional_t, + cute::Shape<_128, _64>, + EpilogueTileType + >; + + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_1>{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + + using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized; + + using AlignmentCType = Int; + using AlignmentDType = Int; + + static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + + static_assert(is_base_of_v, "only support EVT fusions"); + using FusionCallbacks = + typename detail::CallbacksBuilder< + DispatchPolicy, + FusionOp, + CtaTileShape_MNK, + EpilogueTile, + ElementAccumulator, + AccLoadOp + >::Callbacks; + + using CollectiveOp = cute::conditional_t< + cute::is_same_v, + cutlass::epilogue::collective::CollectiveEpilogue< + cutlass::epilogue::Sm100NoSmemWarpSpecialized, + EpilogueTile, + ElementC_, + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + FusionCallbacks, + AccLoadOp, + AlignmentCType, + AlignmentDType + >, + cutlass::epilogue::collective::CollectiveEpilogue< + cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized, + EpilogueTile, + ElementC_, + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + FusionCallbacks, + AccLoadOp + > + >; +}; + +// TMA epilogue builder +template < + class OpClass, + class CtaTileShape_MNK, // Static CTA tile shape + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm100, + OpClass, + CtaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp, + cute::enable_if_t< + // OpClass + ( cute::is_same_v + || cute::is_same_v + ) && + // Epilogue Schedule Type + ( cute::is_base_of_v || + cute::is_base_of_v + || detail::IsPtrArrayDispatchPolicy + )>> + { +private: + using TmemWarpShape_MN = cute::conditional_t(CtaTileShape_MNK{}) == 64 && + (cute::is_base_of_v + || cute::is_same_v + ), + Shape<_2,_2>, Shape<_4,_1>>; + +public: + using CollectiveOp = + typename detail::Sm100TmaBuilderImpl< + OpClass, + CtaTileShape_MNK, + EpilogueTileType, + TmemWarpShape_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp + >::CollectiveOp; +}; + +// Auto builder +template < + class OpClass, + class CtaTileShape_MNK, // Static CTA tile shape + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm100, + OpClass, + CtaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + 
ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp, + cute::enable_if_t< + // OpClass + ( cute::is_same_v + || cute::is_same_v + ) + // Epilogue Schedule Type + && cute::is_same_v> +> + { +private: + static_assert(cute::is_same_v, "Don't specify epilogue tile with auto schedule"); + using TmemWarpShape_MN = cute::conditional_t(CtaTileShape_MNK{}) == 64 && + size<0>(ClusterShape_MNK{}) % 2 == 0 + , + Shape<_2,_2>, Shape<_4,_1>>; +public: + using CollectiveOp = + typename detail::Sm100TmaBuilderImpl< + OpClass, + CtaTileShape_MNK, + EpilogueTileType, + TmemWarpShape_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp + >::CollectiveOp; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 1392428040..0f11bc34eb 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -167,7 +167,7 @@ sm90_compute_tile_shape_or_override() { } } -// callbacks builder with TMA aux out +// aux fusion callbacks builder for sm90 tma epilogue template < int StagesC, int StagesD, @@ -177,6 +177,7 @@ template < class FusionOp, class TileShape_MNK, class EpilogueTile_MN, + class AccLoadOp, class ElementAccumulator > struct CallbacksBuilder< @@ -185,8 +186,9 @@ struct CallbacksBuilder< TileShape_MNK, EpilogueTile_MN, ElementAccumulator, + AccLoadOp, cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor - && not cute::is_subbyte_v> + && not cute::is_subbyte_v> // aux subbyte tensor doesn't use smem > { using GmemStrideTypeAux = gemm::TagToStrideC_t; using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< @@ -213,6 +215,7 @@ template < class FusionOp, class TileShape_MNK, class EpilogueTile_MN, + class AccLoadOp, class ElementAccumulator > struct CallbacksBuilder< @@ -221,6 +224,7 @@ struct CallbacksBuilder< TileShape_MNK, EpilogueTile_MN, ElementAccumulator, + AccLoadOp, cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor && sizeof_bits_v == 1> > { diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index f953930298..10643c236f 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -48,7 +48,6 @@ struct EpilogueTileAuto {}; // Used to let the builder pick the epilogue schedule automatically. 
// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp struct EpilogueScheduleAuto {}; -struct EpilogueIm2ColScheduleAuto {}; template < class ArchTag, @@ -83,6 +82,7 @@ template< class TileShape_MNK, class EpilogueTile_MN, class ElementAccumulator, + class AccLoadOp = cute::DefaultCopy, class = void > struct CallbacksBuilder { @@ -95,6 +95,7 @@ template < class FusionCallbacks, class TileShape_MNK, class EpilogueTile_MN, + class AccLoadOp, class ElementAccumulator > struct CallbacksBuilder< @@ -103,6 +104,7 @@ struct CallbacksBuilder< TileShape_MNK, EpilogueTile_MN, ElementAccumulator, + AccLoadOp, cute::enable_if_t> > { using Callbacks = FusionCallbacks; @@ -117,4 +119,5 @@ struct CallbacksBuilder< ///////////////////////////////////////////////////////////////////////////////////////////////// #include "builders/sm90_builder.inl" +#include "builders/sm100_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 4a6e558b68..918017efa4 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -65,6 +65,10 @@ class CollectiveEpilogue { #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" #include "sm90_epilogue_array_tma_warpspecialized.hpp" +#include "sm100_epilogue_nosmem.hpp" +#include "sm100_epilogue_array_nosmem.hpp" +#include "sm100_epilogue_tma_warpspecialized.hpp" +#include "sm100_epilogue_array_tma_warpspecialized.hpp" // // Conv // diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index 562f77242a..45ebd184f2 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -165,9 +165,9 @@ class DefaultEpilogue { BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, + [[maybe_unused]] ResidueMNK, int thread_idx, - [[maybe_unused]] char* smem_buf) + [[maybe_unused]] char*) { using namespace cute; using X = Underscore; @@ -186,20 +186,20 @@ class DefaultEpilogue { auto stride_d = detail::get_epilogue_stride(params.dD); // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); 
// (BLK_M,BLK_N) + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) // Partition source and destination tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), @@ -207,15 +207,16 @@ class DefaultEpilogue { CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), "Accumulator count must have the same destination element count."); - // Make an identity coordinate tensor for predicating our output MN tile - auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor tCcD = thr_mma.partition_C(cD); + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(blk_shape_MNK), make_coord(m_coord, n_coord)); // (BLK_M,BLK_N) + Tensor tCcD = thr_mma.partition_C(cD_mn); // (VEC,THR_M,THR_N) // source is needed if (epilogue_op.is_source_needed()) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + if (elem_less(tCcD(i), make_shape(M,N))) { tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); } } @@ -224,7 +225,7 @@ class DefaultEpilogue { else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + if (elem_less(tCcD(i), make_shape(M,N))) { tCgD(i) = epilogue_op(accumulators(i)); } } diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp index e4d0fc89c3..3cab46ddcf 100644 --- a/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -169,9 +169,9 @@ class DefaultEpilogueArray { BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, + [[maybe_unused]] ResidueMNK, int thread_idx, - [[maybe_unused]] char* smem_buf) + [[maybe_unused]] char*) { using namespace cute; using X = Underscore; @@ -230,18 +230,18 @@ class DefaultEpilogueArray { if (epilogue_op.is_source_needed()) { ptr_C_l = params.ptr_C[l_coord]; } - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gC 
= gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) // Partition source and destination tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), @@ -249,15 +249,16 @@ class DefaultEpilogueArray { CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), "Accumulator count must have the same destination element count."); - // Make an identity coordinate tensor for predicating our output MN tile - auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor tCcD = thr_mma.partition_C(cD); + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(blk_shape_MNK), make_coord(m_coord, n_coord)); // (BLK_M,BLK_N) + Tensor tCcD = thr_mma.partition_C(cD_mn); // (VEC,THR_M,THR_N) // source is needed if (epilogue_op.is_source_needed()) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + if (elem_less(tCcD(i), make_shape(M,N))) { tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); } } @@ -266,7 +267,7 @@ class DefaultEpilogueArray { else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + if (elem_less(tCcD(i), make_shape(M,N))) { tCgD(i) = epilogue_op(accumulators(i)); } } diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index d5194064cf..5d9d6817d9 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -208,6 +208,25 @@ struct IsThreadEpilogueOpWithElementwiseArguments< ThreadEpilogueOp, cute::void_t> : cute::true_type {}; + +// Check if ActivationFn has 'Arguments' type defined +template +struct sm100_act_has_arguments : cute::false_type {}; + +template +struct sm100_act_has_arguments > : cute::true_type {}; + +template +struct Sm100EpilogueOpNumAccumulatorMtxs { + static constexpr int value = 1; +}; + +template +struct Sm100EpilogueOpNumAccumulatorMtxs> { + static constexpr int value = EpilogueOp::NumAccumulatorMtxs; +}; + + // Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels template class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { @@ -465,6 +484,337 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } }; + +// Wrapper class to use operator-style epilogues in sm100 TMA warp-specialized kernels +template +class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { +public: + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; // 0 stage to disable smem alloc + using LoadPipelineState = cutlass::PipelineState<0>; + + using StorePipeline = cutlass::PipelineTmaStore<1>; // tma store pipe 
has no smem alloc + using StorePipelineState = cutlass::PipelineState<1>; + + using TensorStorage = typename EpilogueOp::SharedStorage; + using TensorMapStorage = typename EpilogueOp::SharedStorage; + using PipelineStorage = typename LoadPipeline::SharedStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. + static constexpr int NumAccumulatorMtxs = Sm100EpilogueOpNumAccumulatorMtxs::value; + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK) { + return 1; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK) { + return 1; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { + } + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return false; + } + + // ctor inheritance + using EpilogueOp::EpilogueOp; + + CUTLASS_DEVICE auto + load_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] int32_t const sm_count, + [[maybe_unused]] int32_t const sm_idx) const { + return cute::make_tuple(nullptr); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + bool reverse_epi_n = false) + { + // C load is performed in epilogue operator + return load_pipe_producer_state; + } + + // with Tensormap + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileShapeMNK, + class CtaTileCoordMNKL, + class MmaTileMNK, + class TiledMma, + class TensorMap + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileShapeMNK tile_shape_mnk, + CtaTileCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + [[maybe_unused]] cute::tuple const& load_tensormap_info, + bool reverse_epi_n = false) + { + // C load is performed in epilogue operator + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) + { + } + + CUTLASS_DEVICE auto + store_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] int32_t const sm_count, + [[maybe_unused]] int32_t const sm_idx) const { + return cute::make_tuple(nullptr); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL 
problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors + ) + { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + + auto [acc_state_next] = (*this).template operator()( + acc_pipeline, + acc_pipe_consumer_state, + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + shared_tensors); + + // Let mma warp know tmem buffer is consumed and empty + ++load_pipe_consumer_state; + ++store_pipe_producer_state; + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); + } + + // FastF32 API + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TiledCopyT2R tiled_t2r) + { + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tTR_rAcc, + shared_tensors, + tiled_t2r); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + // FastF32 API with Tensor Map + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R, + class TensorMap + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TensorMap tensormap, + TiledCopyT2R tiled_t2r) { + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tTR_rAcc, + shared_tensors, + tiled_t2r); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class TileCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TensorMap + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + TileCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors, + TensorMap tensormap + ) + { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + + auto [acc_state_next] = (*this).template operator()( + acc_pipeline, + acc_pipe_consumer_state, + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + shared_tensors); + + // Let mma warp know tmem 
buffer is consumed and empty + ++load_pipe_consumer_state; + ++store_pipe_producer_state; + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); + } + + template + CUTLASS_DEVICE void + store_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state, + [[maybe_unused]] CtaTileMNK cta_tile_mnk) + { + } + + // Dummy methods to perform different parts of TMA/Tensormap modifications + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] ProblemShape problem_shape, + [[maybe_unused]] int32_t next_batch) { } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] cute::TmaDescriptor const* tensormap) { } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } +}; + + // SFINAE helpers for detecting beta/beta_ptr/beta_ptr_array in EVT arguments. template struct has_beta { diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp new file mode 100644 index 0000000000..c1b06b06d9 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Functor performing elementwise operations used by Ptr-Array and Grouped GEMM epilogues. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_ +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100PtrArrayNoSmem; + using EpilogueTile = EpilogueTile_; + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpT2R = CopyOpT2R_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + constexpr static int ThreadCount = 128; + constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; + constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + struct TensorStorage { }; + struct TensorMapStorage { }; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + 
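The Arguments struct above carries one C pointer and one D pointer per batch (ElementC const** / ElementD**), and the device-side operator() later selects them with the batch index l_coord before building the per-batch tensors. The following standalone sketch is not part of the patch; all names, sizes, and values are illustrative. It only mirrors the pointer-per-batch selection and the is_source_needed() branch on the host.

// Standalone sketch (not part of the patch): how a ptr-array epilogue resolves
// per-batch C/D pointers, mirroring params.ptr_C[l_coord] / params.ptr_D[l_coord].
// All names and sizes here are illustrative.
#include <cstdio>
#include <vector>

int main() {
  constexpr int kBatches = 3;
  constexpr int kElems   = 4;                       // tiny per-batch tensors
  std::vector<std::vector<float>> C(kBatches, std::vector<float>(kElems, 1.0f));
  std::vector<std::vector<float>> D(kBatches, std::vector<float>(kElems, 0.0f));

  // Arrays of per-batch pointers, analogous to Arguments::ptr_C / Arguments::ptr_D.
  std::vector<const float*> ptr_C(kBatches);
  std::vector<float*>       ptr_D(kBatches);
  for (int l = 0; l < kBatches; ++l) { ptr_C[l] = C[l].data(); ptr_D[l] = D[l].data(); }

  float alpha = 2.0f, beta = 0.5f;
  bool source_needed = (beta != 0.0f);              // is_source_needed() analogue

  for (int l = 0; l < kBatches; ++l) {              // l plays the role of l_coord
    const float* C_l = source_needed ? ptr_C[l] : nullptr;
    float*       D_l = ptr_D[l];
    for (int i = 0; i < kElems; ++i) {
      float acc = float(l + i);                     // stand-in for an accumulator value
      D_l[i] = source_needed ? alpha * acc + beta * C_l[i] : alpha * acc;
    }
  }
  std::printf("D[2][3] = %f\n", D[2][3]);           // 2*5 + 0.5*1 = 10.5
  return 0;
}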
CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&) { + + using namespace cute; + using X = Underscore; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // Batches are managed by using appropriate pointers to C and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor 
tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + // Detect interleaved complex fp32 kernels + Tensor accs = accumulators; + using ElementTmem = typename decltype(accs)::value_type; + constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; + + // 1. Load accumulators into register from tmem + // Tmem -> rmem and transformation for interleaved complex kernels + if constexpr (is_interleaved_complex_f32) { + using ElementComputeAccumulator = float; + + Tensor tAccReal = accumulators(make_coord(_,_),_0{},_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAccImag = accumulators(make_coord(_,_),_0{},_0{},_1{}); // (CTA_M,CTA_N) + Tensor tTR_tAccReal = thread_t2r.partition_S(tAccReal); // (T2R,T2R_M,T2R_N) + Tensor tTR_tAccImag = thread_t2r.partition_S(tAccImag); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccReal = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccImag = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAccReal, tTR_rAccReal); + copy(tiled_t2r, tTR_tAccImag, tTR_rAccImag); + + // 1.1. Transform accumulators in registers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAccReal); i++) { + tTR_rAcc(i) = {tTR_rAccReal(i), tTR_rAccImag(i)}; + } + } + + // Standard tmem -> rmem epilogue + else { + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + } + + // 2. Apply element-wise operation and store to gmem + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rAcc(i), tTR_gC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rAcc(i)); + } + } + } + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + return cute::make_tuple(acc_pipe_consumer_state); + } + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. 
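The store loops above guard every element with elem_less(tTR_cD(i), problem_shape_mnl), comparing an absolute (m, n) coordinate built from an identity tensor against the problem extent; this is also the pattern the DefaultEpilogue hunks at the top of this section switch to. Below is a minimal standalone illustration of that predication, not part of the patch, with made-up problem and tile sizes.

// Standalone sketch (not part of the patch): bounds predication with absolute
// coordinates, i.e. the elem_less(tTR_cD(i), problem_shape_mnl) check above.
#include <cstdio>

int main() {
  constexpr int M = 100, N = 70;            // problem extent
  constexpr int CTA_M = 64, CTA_N = 64;     // CTA tile

  int m_coord = 1, n_coord = 1;             // tile coordinate of this CTA (a partial tile)
  int stored = 0, skipped = 0;

  for (int i = 0; i < CTA_M; ++i) {
    for (int j = 0; j < CTA_N; ++j) {
      int m = m_coord * CTA_M + i;          // absolute row, like the identity tensor's (m, n)
      int n = n_coord * CTA_N + j;          // absolute column
      bool in_bounds = (m < M) && (n < N);  // elem_less((m,n), (M,N)) analogue
      in_bounds ? ++stored : ++skipped;     // the store to D happens only when in bounds
    }
  }
  // 36 valid rows x 6 valid columns are stored; the rest are predicated off.
  std::printf("stored=%d skipped=%d\n", stored, skipped);
  return 0;
}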
+ template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledCopy + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor& tTR_rGlobAcc, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&, + TiledCopy tiled_t2r) { + + using namespace cute; + using X = Underscore; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // Batches are managed by using appropriate pointers to C and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + ThreadEpilogueOp epilogue_op{params.thread}; + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N) + + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide( cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + // 2. 
Apply element-wise operation and store to gmem + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i), tTR_gC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i)); + } + } + } + } + +protected: + Params const& params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For sm100 kernels requiring warp specialized epilogues +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_ +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmemWarpSpecialized, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_ +> : public detail::Sm100TmaWarpSpecializedAdapter> +{ +public: + // ctor inheritance + using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp new file mode 100644 index 0000000000..0b00720839 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -0,0 +1,1190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing elementwise operations used by Ptr-Array and Grouped Gemm epilogue. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileShape_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm100PtrArrayTmaWarpSpecialized, + CtaTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100PtrArrayTmaWarpSpecialized; + using CtaTileShape = CtaTileShape_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpT2R = CopyOpT2R_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + constexpr static int ThreadCount = 128; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + +private: + using GmemElementD = ElementD; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using 
SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + static_assert(StagesC >= 1, "StagesC must be >= 1"); + static_assert(StagesD >= 1, "StagesD must be >= 1"); + + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static bool is_source_supported = not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + using SmemLayoutStageC = decltype(tile_to_shape(SmemLayoutAtomC{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayoutStageD = decltype(tile_to_shape(SmemLayoutAtomD{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + + constexpr static int StageCBits = cosize_v * sizeof_bits_v; + constexpr static int StageDBits = cosize_v * sizeof_bits_v; + constexpr static int MaxStageBits = cute::max(StageCBits, StageDBits); + constexpr static int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / sizeof_bits_v; + constexpr static int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / sizeof_bits_v; + + using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); + using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); + + constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + && MaxStageBits % sizeof_bits_v == 0 + && MaxStageBits % sizeof_bits_v == 0; + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = StageCBits / 8; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_C; + cute::TmaDescriptor smem_tensormap_D; + } tensormaps; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Planar complex kernels 
have two accumulator copies for the real and imaginary tensors. + constexpr static int NumAccumulatorMtxs = 1; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + struct Params { + using TensorShapeC = decltype(repeat_like(append<3>(StrideC{}, _1{}), int32_t(0))); + using TensorShapeD = decltype(repeat_like(append<3>(StrideD{}, _1{}), int32_t(0))); + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor( + make_gmem_ptr(static_cast,ElementD,ElementC> const*>(nullptr)), + TensorShapeC{}, + append<3>(InternalStrideC{}, _0{})), + SmemLayoutStageC{}, + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + TensorShapeD{}, + append<3>(InternalStrideD{}, _0{})), + SmemLayoutStageD{}, + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + cute::TmaDescriptor* tensormaps; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // + // Gemm Host Functions + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_L = 1; + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_c = InternalStrideC{}; + stride_d = InternalStrideD{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
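The smem sizing above computes per-stage footprints in bits for C and D, takes the larger of the two when the C buffer is reused for D, and derives the per-stage stride and the TMA transaction byte count from it. The compile-time sketch below is not part of the patch and assumes a 64x32 epilogue tile with 16-bit C and 8-bit D elements purely for illustration.

// Standalone sketch (not part of the patch): the stage-size arithmetic for an
// assumed epilogue tile and assumed element widths.
#include <cstdio>
#include <algorithm>

int main() {
  constexpr int EpiTileM = 64, EpiTileN = 32;          // EpilogueTile (assumed)
  constexpr int bits_C = 16;                           // e.g. a 16-bit source type
  constexpr int bits_D = 8;                            // e.g. an 8-bit output type
  constexpr bool ReuseSmemC = true;

  constexpr int StageCBits   = EpiTileM * EpiTileN * bits_C;   // cosize(SmemLayoutStageC) * sizeof_bits
  constexpr int StageDBits   = EpiTileM * EpiTileN * bits_D;
  constexpr int MaxStageBits = std::max(StageCBits, StageDBits);

  // When C's smem is reused for D, both stage strides advance by the larger footprint.
  constexpr int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / bits_C;  // in C elements
  constexpr int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / bits_D;  // in D elements

  constexpr int TmaTransactionBytes = StageCBits / 8;  // bytes expected per C load stage

  std::printf("StageCBits=%d StageDBits=%d stride_C=%d stride_D=%d tma_bytes=%d\n",
              StageCBits, StageDBits, StrideStageC, StrideStageD, TmaTransactionBytes);
  return 0;
}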
+ auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + init_M = get<0>(problem_shape_MNKL); + init_N = get<1>(problem_shape_MNKL); + + stride_c = args.dC; + stride_d = args.dD; + } + + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + // Tensor pointers will be fixed before the first access + ElementC const* ptr_C_first_batch = nullptr; + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); + tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, EpilogueTile{}, _1{}); + } + + // Tensor pointers will be fixed before the first access + ElementD* ptr_D_first_batch = nullptr; + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + typename Params::TMA_D tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); + + auto fusion_workspace = static_cast(workspace); + auto fusion_workspace_size = FusionCallbacks::get_workspace_size(problem_shape, args.thread); + auto tma_descriptor_workspace = reinterpret_cast( + static_cast(workspace) + fusion_workspace_size); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), + tma_load_c, + tma_store_d, + tma_descriptor_workspace, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + bool implementable = true; + bool fusion_implementable = true; + + if (problem_shape.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + + if constexpr (is_source_supported) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + } + else { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); + 
} + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if (cute::is_void_v || args.ptr_C == nullptr) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + if constexpr (detail::has_beta_ptr_array::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + // + // Static Device Functions + // + + template + CUTLASS_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(cta_tile_mnk)), EpilogueTile{})); + } + + template + CUTLASS_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + return get_load_pipe_increment(cta_tile_mnk); + } + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + + // + // Non-static Device Functions + // +public: + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + CUTLASS_DEVICE auto + load_init( + Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = true; + auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + return cute::make_tuple(load_tensormap); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class TensorMapC + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + cute::tuple load_tensormap_info, + bool reverse_epi_n = false) { + using namespace cute; + + // Check to see if tensormaps have been replaced in gmem + if (get<1>(load_tensormap_info) /* did_batch_change */) { + tensormaps_fence_acquire(get<0>(load_tensormap_info)); + } + + int lane_idx = canonical_lane_idx(); + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(cta_tile_mnk)); + Tensor gC = local_tile(mC, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + 
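get_load_pipe_increment above counts how many epilogue subtiles fit in one CTA tile (via zipped_divide of the CTA MN extent by EpilogueTile), and the store increment matches it, so both pipelines advance once per subtile. The arithmetic below is not part of the patch and assumes a 128x256 CTA tile and a 64x32 epilogue tile.

// Standalone sketch (not part of the patch): the epilogue subtile count that the
// load/store pipe increments are based on, for assumed tile shapes.
#include <cstdio>

int main() {
  constexpr int CTA_M = 128, CTA_N = 256;     // take<0,2>(cta_tile_mnk), assumed
  constexpr int EPI_M = 64,  EPI_N = 32;      // EpilogueTile, assumed

  // zipped_divide tiles the CTA extent by the epilogue tile; the number of such
  // subtiles is how far the load and store pipelines advance per CTA tile.
  constexpr int subtiles = (CTA_M / EPI_M) * (CTA_N / EPI_N);

  std::printf("pipe increment per CTA tile = %d\n", subtiles);  // 2 * 8 = 16
  return 0;
}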
// Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + lane_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gC_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gC_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gC_epi) - 1 - iter_n; + } + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(get<0>(load_tensormap_info), *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_init( + Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = false; + cute::TmaDescriptor* store_tensormap = nullptr; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + // Only the first epilogue warp needs to perform TMA related operations + if (warp_idx == 0) { + store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + } + return cute::make_tuple(store_tensormap); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class 
TiledMma, + class AccEngine, + class AccLayout, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors, + cute::tuple store_tensormap_info + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(accumulators) == 3, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); + static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // Check to see if tensormaps have been replaced in gmem + // Only the first epilogue warp needs to perform TMA related operations + if (get<1>(store_tensormap_info) /* did_batch_change */ && warp_idx == 0) { + tensormaps_fence_acquire(get<0>(store_tensormap_info)); + } + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + // Does directly store the visitor into smem. 
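The producer load loop above advances load_pipe_producer_state once per epilogue subtile and uses its index() to pick the smem stage and its count for barrier bookkeeping. The toy state below is not the CUTLASS PipelineState class, only an illustration of the wrap-around count/index/phase relationship that the acquire/commit handshake relies on.

// Standalone sketch (not part of the patch): a minimal stand-in for the pipeline-state
// bookkeeping -- a monotonically increasing count, a stage index that wraps modulo the
// stage depth, and a phase bit that flips on every wrap.
#include <cstdio>

template <int Stages>
struct ToyPipelineState {
  int count = 0;
  int index() const { return count % Stages; }          // which smem buffer / barrier
  int phase() const { return (count / Stages) & 1; }    // parity used when waiting on barriers
  ToyPipelineState& operator++() { ++count; return *this; }
};

int main() {
  constexpr int StagesC = 4;              // e.g. depth of the C load pipeline (assumed)
  ToyPipelineState<StagesC> state;

  for (int subtile = 0; subtile < 10; ++subtile) {
    // producer_acquire(state) / TMA copy / producer_commit(state) would go here
    std::printf("subtile %d -> stage %d, phase %d\n", subtile, state.index(), state.phase());
    ++state;                              // same advancement as ++load_pipe_producer_state
  }
  return 0;
}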
+ constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tTR_rAcc = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + Tensor tRR_rD_src = thread_r2r.retile_S(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_rD = [&]() { + if constexpr (!IsDirectR2S) { + return make_tensor(shape(tRS_sD(_,_,_,_0{}))); + } + else{ + return thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + } + }(); + + Tensor tRR_rD_dst_frg = recast>(coalesce(tRR_rD_dst)); + Tensor tRS_rD_frg = recast>(coalesce(tRS_rD)); + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_counting_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store 
warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for sub-128 thread T2R tiled copy + Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; + constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); + bool issue_tmem_load = true; + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. 
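The comment above explains why, with ReuseSmemC, the C-load buffers can only be released after enough D stores have been committed that the shared smem is guaranteed to be drained. The host-side simulation below is not part of the patch; it assumes StagesD = 2 and a release threshold of StagesD - 1 purely to illustrate how releases trail commits.

// Standalone sketch (not part of the patch): deferred smem-buffer release accounting
// for the smem-reuse configuration, under assumed stage counts.
#include <cstdio>

int main() {
  constexpr int StagesD = 2;                        // assumed D store pipeline depth
  constexpr int UnacquiredStages = StagesD - 1;     // toy threshold standing in for
                                                    // StorePipeline::UnacquiredStages

  int store_commits    = 0;                         // tracks store_pipe_producer_state.count()
  int buffers_released = 0;                         // tracks load_pipe_consumer_state advances

  for (int subtile = 0; subtile < 8; ++subtile) {
    ++store_commits;                                // producer_commit for this subtile's D store
    // The next producer_acquire only returns once few enough stores are in flight,
    // so only then is the oldest smem buffer certainly drained and safe to hand back.
    bool store_finished = store_commits > UnacquiredStages;
    if (store_finished) {
      ++buffers_released;                           // consumer_release on the load pipeline
    }
  }
  // Releases trail commits by StagesD - 1 in this toy model: 8 commits -> 7 releases.
  std::printf("commits=%d releases=%d\n", store_commits, buffers_released);
  return 0;
}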
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + auto epi_loop_fn = [&] (auto& cst_callbacks) { + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + // Begin the wait for the accumulator results + ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + // For each epilogue subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gD_epi) - 1 - iter_n; + } + do_acc_release = iter_m == size<2>(gD_epi)-1 && iter_n == 0; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, 
tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if (is_first_iteration) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); + } + + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + // Compute tmem load predication if necessary + if constexpr (predicate_tmem_load) { + // Issue tmem load if this tile's tmem subpartition is accessible by this warp + int subpart_idx = (tTR_tAcc_mn.data().dp_ / 32) % 4; + issue_tmem_load = warp_idx == subpart_idx; + } + + // Copy accumulator tile from tmem to register + if (issue_tmem_load) { + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + if constexpr (!IsDirectR2S) { + // At present, only FP4 col output with scalefactor generation fusion would go into these branch + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + tRS_rD_frg(_0{}) = cutlass::NumericArrayConverter{}(tRR_rD_dst_frg(_0{})); + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + + // Copy output tile from register to smem + bool issue_smem_store = issue_tmem_load; + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if 
constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + }; + + epi_loop_fn(cst_callbacks); + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); + } + + template + CUTLASS_DEVICE void + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + CtaTileMNK cta_tile_mnk) { + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init(Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* tma_desc = nullptr; + cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + if constexpr (IsLoad) { + if (is_source_supported) { + tma_desc = &gmem_tensormap[sm_idx]; + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{}); + copy(recast(pC_tensormap), recast(sC_tensormap)); + } + __syncwarp(); + } + } else { + int const offset_Ddesc = cute::is_void_v ? 
0 : sm_count; + tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{}); + copy(recast(pD_tensormap), recast(sD_tensormap)); + } + __syncwarp(); + } + + return tma_desc; + } + + // Replace address for the global tensor (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& params, + int32_t next_batch) { + // Replacing global_address for the next batch + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.ptr_C != nullptr) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C, + params.ptr_C[next_batch]); + } + } + } else { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, + params.ptr_D[next_batch]); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape = {1,1,1,1,1}; + cute::array prob_stride = {0,0,0,0,0}; + + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.dC != nullptr) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); + } + } + } + else { + ElementD const* ptr_D = nullptr; + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D, + prob_shape, + prob_stride); + } + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& params, + cute::TmaDescriptor const* tensormap, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties( + shared_tensormap, params, next_batch, problem_shape_MNKL); + 
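For grouped GEMM, the tensormap update above rewrites the descriptor's shape and strides for every group. TMA descriptors take byte strides while the kernel carries element strides, so the conversion goes through bit widths to stay exact for sub-byte types. A minimal stand-alone sketch with hypothetical names (not the CUTLASS helpers):

#include <cstdint>

// Convert an element stride into the byte stride a TMA descriptor expects.
inline uint64_t to_byte_stride(uint64_t elem_stride, int sizeof_bits_of_element) {
  return (elem_stride * sizeof_bits_of_element) / 8;
}

// Example: a leading dimension of 1024 elements is 1024 bytes for an 8-bit
// type but only 512 bytes for a 4-bit type.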
} + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormap, tensormap); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormap, + cute::TmaDescriptor const* tensormap) { + // Entire warp must do this (ie its aligned) + if constexpr (IsLoad) { + if (is_source_supported) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); + } + } else { + tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { + if constexpr (IsLoad) { + if (is_source_supported) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } else { + cute::tma_descriptor_fence_acquire(tensormap); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp new file mode 100644 index 0000000000..85ede7031c --- /dev/null +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -0,0 +1,819 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/helper_macros.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +template +struct IsDefaultFusionOp { + static constexpr bool value = false; +}; + +template< + class ElementD, class ElementCompute, + class ElementC, FloatRoundStyle RoundStyle +> +struct IsDefaultFusionOp< + epilogue::fusion::LinearCombination< + ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle> +> { + static constexpr bool value = true; +}; + +template< + class ElementOutput, int Count, class ElementAccumulator, + class ElementCompute, epilogue::thread::ScaleType::Kind Scale, + FloatRoundStyle Round, class ElementSource +> +struct IsDefaultFusionOp< + epilogue::thread::LinearCombination< + ElementOutput, Count, ElementAccumulator, + ElementCompute, Scale, Round, ElementSource> +> { + static constexpr bool value = true; +}; + +// Legacy direct store sm100 epilogue using thread::LinearCombination, do not expect this to be stable +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100NoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100NoSmem; + using EpilogueTile = EpilogueTile_; + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpT2R = CopyOpT2R_; + using AlignmentC = AlignmentC_; + using AlignmentD = AlignmentD_; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + constexpr static int ThreadCount = 128; + constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; + constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + 
static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; + +protected: + Params const& params; + + // + // Non-static Device Methods + // +public: + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&) { + + using namespace cute; + using X = Underscore; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); + auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; + constexpr int VD = cute::min(AlignmentD{}, size(mclD)); + Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); + Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); + + Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); + Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tDpD); t++) { + tDpD(t) = elem_less(tTR_cD_mn_frg(t), 
problem_shape_mnl); + } + + constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gC.layout())){}; + constexpr int VC = cute::min(AlignmentC{}, size(mclC)); + + Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); + Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); + Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tCpC); t++) { + tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); + } + Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); + Tensor tTR_rC_dst = recast>(coalesce(tTR_rC)); + + // Detect interleaved complex fp32 kernels + [[maybe_unused]] Tensor accs = accumulators; + using ElementTmem = typename decltype(accs)::value_type; + constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; + + // 1. Load accumulators into register from tmem + // Tmem -> rmem and transformation for interleaved complex kernels + if constexpr (is_interleaved_complex_f32) { + using ElementComputeAccumulator = float; + + Tensor tAccReal = accumulators(make_coord(_,_),_0{},_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAccImag = accumulators(make_coord(_,_),_0{},_0{},_1{}); // (CTA_M,CTA_N) + Tensor tTR_tAccReal = thread_t2r.partition_S(tAccReal); // (T2R,T2R_M,T2R_N) + Tensor tTR_tAccImag = thread_t2r.partition_S(tAccImag); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccReal = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccImag = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAccReal, tTR_rAccReal); + copy(tiled_t2r, tTR_tAccImag, tTR_rAccImag); + + // 1.1. Transform accumulators in registers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAccReal); i++) { + tTR_rAcc(i) = {tTR_rAccReal(i), tTR_rAccImag(i)}; + } + } + + // Standard tmem -> rmem epilogue + else { + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + } + + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + + // 2. Apply element-wise operation and store to gmem + ThreadEpilogueOp epilogue_op{params.thread}; + // source is needed + if (epilogue_op.is_source_needed()) { + copy_if(tCpC, tTR_rC_src, tTR_rC_dst); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); + } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); + } + // source is not needed, avoid load + else + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); + } + + return cute::make_tuple(acc_pipe_consumer_state); + } + + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. 
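The comment above describes the emulated-FP32 (FastFP32) flow: the MMA accumulates into tmem, and the kernel periodically drains those partial sums into a register-resident global accumulator so the tmem accumulator can be cleared and rounding error does not build up. The sketch below is one plausible host-side reading of that scheme; the drain interval, element types, and names are illustrative assumptions, not the kernel's actual implementation.

#include <cstddef>
#include <vector>

std::vector<float> accumulate_periodically(std::vector<std::vector<float>> const& k_chunk_partials,
                                           std::size_t drain_interval) {
  std::size_t n = k_chunk_partials.empty() ? 0 : k_chunk_partials.front().size();
  std::vector<float> global_acc(n, 0.0f);   // register-resident running total
  std::vector<float> staging(n, 0.0f);      // stands in for the tmem accumulator
  for (std::size_t k = 0; k < k_chunk_partials.size(); ++k) {
    for (std::size_t i = 0; i < n; ++i) staging[i] += k_chunk_partials[k][i];
    bool drain = ((k + 1) % drain_interval == 0) || (k + 1 == k_chunk_partials.size());
    if (drain) {                            // drain staging so it can be cleared
      for (std::size_t i = 0; i < n; ++i) { global_acc[i] += staging[i]; staging[i] = 0.0f; }
    }
  }
  return global_acc;
}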
+ template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledCopy + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor& tTR_rGlobAcc, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&, + TiledCopy tiled_t2r) { + + using namespace cute; + using X = Underscore; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); + auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N) + + + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide( cCD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + // 2. 
Apply element-wise operation and store to gmem + ThreadEpilogueOp epilogue_op{params.thread}; + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cCD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i), tTR_gC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cCD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i)); + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Direct store sm100 epilogue supporting EVT +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100NoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + // Required by the gemm::kernel + using DispatchPolicy = Sm100NoSmem; + using ElementC = ElementC_; + using ElementD = ElementD_; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using StrideC = StrideC_; + using StrideD = StrideD_; + using EpilogueTile = EpilogueTile_; + using CopyOpT2R = CopyOpT2R_; + using FusionCallbacks = FusionCallbacks_; + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + +private: + constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported + || is_same_v; // alloc reduction buffer for custom EVTs + constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? 
size(EpilogueTile{}) : 0; + +public: + constexpr static int ThreadCount = 128; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + array_aligned buffer; + }; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC = {}; + ElementD* ptr_D = nullptr; + StrideD dD = {}; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC = {}; + ElementD* ptr_D = nullptr; + StrideD dD = {}; + }; + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) + : fusion_callbacks(params_.thread, shared_tensors.thread) + , smem_buffer_ptr(shared_tensors.buffer.data()) + , params(params_) {}; + +protected: + FusionCallbacks fusion_callbacks; + uint8_t* smem_buffer_ptr; + Params const& params; + +public: + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + return fusion_implementable; + } + + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + [[maybe_unused]]SharedStorage& + ) { + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + // Wait for mma warp to fill tmem buffer with accumulator results + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); + + auto [M, N, K, L] = problem_shape_mnkl; + auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); + 
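The ImplicitSharedStorageSize logic above reserves a smem scratch buffer for the direct-store epilogue only when the fusion actually performs a cross-thread reduction (per-row bias reduction or a custom EVT); otherwise it is sized to zero. A minimal compile-time sketch of that pattern, with illustrative names:

#include <cstddef>

template <bool NeedsReductionBuffer, std::size_t EpilogueTileElems>
struct ImplicitSmemModel {
  // One byte per epilogue-tile element when a reduction workspace is needed.
  static constexpr std::size_t bytes = NeedsReductionBuffer ? EpilogueTileElems : 0;
  unsigned char buffer[bytes == 0 ? 1 : bytes];   // placeholder byte: zero-length arrays are non-standard
};

// e.g. ImplicitSmemModel<true, 128 * 16> reserves a 2048-byte scratch buffer,
// while ImplicitSmemModel<false, 128 * 16> degenerates to a single padding byte.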
auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); + auto cta_tiler = take<0,2>(cta_tile_mnk); + + int thread_idx = threadIdx.x % ThreadCount; + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + + constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); + Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); + + // Construct the EVT consumer callbacks + auto residue_cD = make_coord(M,N) - cD(_0{}); + auto residue_tTR_cD = make_coord(M,N) - tTR_cD(_0{}); + Tensor cD_ = make_counting_tensor(cD.layout()); + Tensor tTR_cD_ = make_counting_tensor(tTR_cD.layout()); + constexpr bool RefSrc = false; + + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); + + Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mC, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); + + Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D)), make_shape(M,N,L), params.dD); + + Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mD, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); + + // Register Tensor + Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); + + Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); + Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + coord_cCD, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); + constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; + constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); + + auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); + constexpr auto mclC = decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; + constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); + + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); + + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + int(0), + EpilogueTile{}, + tiled_t2r, + cD_, + residue_cD, + tTR_cD_, + residue_tTR_cD, + tCrC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. + synchronize(); + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. 
The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // For each epilogue subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<4>(tTR_tAcc); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<3>(tTR_tAcc); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + + bool is_last_iteration = iter_m == size<3>(tTR_tAcc)-1 && iter_n == size<4>(tTR_tAcc)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<4>(tTR_tAcc) - 1 - iter_n; + } + do_acc_release = iter_m == size<3>(tTR_tAcc)-1 && iter_n == 0; + } + + Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_C_load_needed) { + Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int{}))); + Tensor tTR_gC_frg = recast>(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); + Tensor tTR_rC_frg = recast>(coalesce(tCrC)); + + auto pred_fn_C = [&] (auto const&... coords) { + return elem_less(tTR_cC_frag(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg); + } + + // Copy accumulator tile from tmem to register + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); + + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + Tensor reduction_buffer = make_tensor( + raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); + + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); + + cst_callbacks.end_loop(epi_m, epi_n); + + + Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int{}))); + + using VecType = uint_bit_t>; + Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); + Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); + + auto pred_fn_D = [&] (auto const&... 
coords) CUTLASS_LAMBDA_FUNC_INLINE { + return elem_less(tTR_cD_frag(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg); + + } // for epi_m + } // for epi_n + + cst_callbacks.end(); + }; + + epi_loop_fn(cst_callbacks); + return cute::make_tuple(acc_pipe_consumer_state); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// For sm100 kernels requiring warp specialized epilogues +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100NoSmemWarpSpecialized, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_ +> : public detail::Sm100TmaWarpSpecializedAdapter> +{ +public: + // ctor inheritance + using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000..a144accd92 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,1289 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Functor performing elementwise operations used by epilogues. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/trace.h" + +#include "cutlass/conv/detail.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileShape_, // (CTA_M,CTA_N,CTA_K, optional: Tile_L) + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm100TmaWarpSpecialized, + CtaTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100TmaWarpSpecialized; + using CtaTileShape = CtaTileShape_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpT2R = CopyOpT2R_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + constexpr static int ThreadCount = 128; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + +private: + using GmemElementD = ElementD; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + static_assert(StagesC >= 1, "StagesC must be >= 1"); + static_assert(StagesD >= 1, "StagesD must be >= 1"); + + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static bool is_source_supported = not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = 
detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + using SmemLayoutStageC = decltype(tile_to_shape(SmemLayoutAtomC{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayoutStageD = decltype(tile_to_shape(SmemLayoutAtomD{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + + constexpr static int StageCBits = cosize_v * sizeof_bits_v; + constexpr static int StageDBits = cosize_v * sizeof_bits_v; + constexpr static int MaxStageBits = cute::max(StageCBits, StageDBits); + constexpr static int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / sizeof_bits_v; + constexpr static int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / sizeof_bits_v; + + using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); + using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); + + constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + && MaxStageBits % sizeof_bits_v == 0 + && MaxStageBits % sizeof_bits_v == 0; + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = StageCBits / 8; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. 
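StageCBits and StageDBits above size each smem stage in bits so that, when C smem is reused for D, both element types can share one stage stride; reuse additionally requires that there are no more D stages than C stages and that the shared stride is a whole number of each element type. A simplified host-side sketch of that arithmetic, with illustrative names:

struct SmemStagePlan {
  int stage_c_bits;
  int stage_d_bits;

  int max_stage_bits() const { return stage_c_bits > stage_d_bits ? stage_c_bits : stage_d_bits; }

  bool reuse_supported(int stages_c, int stages_d, int sizeof_bits_c, int sizeof_bits_d) const {
    return stages_d <= stages_c
        && max_stage_bits() % sizeof_bits_c == 0
        && max_stage_bits() % sizeof_bits_d == 0;
  }

  // Stride between D stages, in D elements, with and without C/D smem reuse.
  int stride_stage_d_elems(bool reuse, int sizeof_bits_d) const {
    return (reuse ? max_stage_bits() : stage_d_bits) / sizeof_bits_d;
  }
};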
+ constexpr static int NumAccumulatorMtxs = 1; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + +private: + static constexpr auto + get_tma_epi_tile() { + return cute::transform_apply(EpilogueTile{}, seq<0,1>{}, + [] (auto epi_tiler, auto mode) { + auto cta_tiler_shape = get(CtaTileShape{}); + // Use a dynamic stride to prevent mode coalescing + auto cta_tiler_stride = repeat_like(cta_tiler_shape, 0); + auto cta_tiler = make_layout(cta_tiler_shape, cta_tiler_stride); + // This is a multimodal CTA tiler, transform before returning + if constexpr (depth(cta_tiler) > 0) { + // This is an implicit multimodal tiler, match profile and return + if constexpr (tuple_size_v == 1) { + return make_tile(epi_tiler); + } + // This is an explicit multimodal tiler, compose out epi tiler + else { + return shape(composition(cta_tiler, epi_tiler)); + } + } + // This is a flat CTA tiler, no need for transformation + else { + return epi_tiler; + } + }, + [] (auto... epi_tilers) { + return make_tile(epi_tilers...); + } + ); + } + + using TmaEpilogueTile = decltype(get_tma_epi_tile()); + + template + static constexpr auto + get_tma_load_c(ProblemShapeMNL const& problem_shape_mnl, Arguments const& args) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), + make_layout(problem_shape_mnl, append<3>(args.dC, _0{}))); + return make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, TmaEpilogueTile{}, _1{}); + } + + template + static constexpr auto + get_tma_store_d(ProblemShapeMNL const& problem_shape_mnl, Arguments const& args) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), + make_layout(problem_shape_mnl, append<3>(args.dD, _0{}))); + return make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, TmaEpilogueTile{}, _1{}); + } + +public: + // Device side epilogue params + struct Params { + using TMA_C = decltype(get_tma_load_c (repeat_like(append<3>(StrideC{},_1{}), int32_t(0)), Arguments{})); + using TMA_D = decltype(get_tma_store_d(repeat_like(append<3>(StrideD{},_1{}), int32_t(0)), Arguments{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + }; + + // + // Gemm Host Functions + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnl = select<0,1,3>(append<4>(problem_shape, 1)); + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + tma_load_c = get_tma_load_c(problem_shape_mnl, args); + } + + typename Params::TMA_D tma_store_d = get_tma_store_d(problem_shape_mnl, args); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + 
ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_d = cutlass::detail::get_output_alignment_bits(); + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_d / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); + } + + if constexpr (is_source_supported) { + constexpr int tma_alignment_bits_c = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_c / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + // + // Conv Host Functions + // + + template + static constexpr Params + to_underlying_arguments(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return to_underlying_arguments(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args, workspace); + } + + template + static size_t + get_workspace_size(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return get_workspace_size(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + + template + static cutlass::Status + initialize_workspace(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args, + void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return initialize_workspace(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + + // + // Static Device Functions + // + + template + CUTLASS_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(cta_tile_mnk)), EpilogueTile{})); + } + + template + CUTLASS_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + return get_load_pipe_increment(cta_tile_mnk); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + 
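can_implement above converts TMA's fixed bit-alignment requirement into a minimum element count for the chosen C/D types and checks the problem shape and strides against it. A small stand-alone sketch of the core check (hypothetical names; the real check also inspects every stride mode):

inline bool tma_alignment_ok(long long contiguous_extent_elems,
                             int tma_alignment_bits,
                             int sizeof_bits_of_element) {
  int min_aligned_elems = tma_alignment_bits / sizeof_bits_of_element;
  return contiguous_extent_elems % min_aligned_elems == 0;
}

// Example: with a 128-bit requirement, an fp16 D tensor needs its contiguous
// dimension to be a multiple of 8 elements, while an fp8 tensor needs 16.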
cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + + // + // Non-static Device Functions + // +public: + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + bool reverse_epi_n = false) { + using namespace cute; + + int lane_idx = canonical_lane_idx(); + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). + auto coord_shape = + conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(cta_tile_mnk)); + Tensor gC = local_tile(mC, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + lane_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gC_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gC_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gC_epi) - 1 - iter_n; + } + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + 
copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(accumulators) == 3, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); + static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
+ auto coord_shape = + conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + // Does directly store the visitor into smem. + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tTR_rAcc = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + Tensor tRR_rD_src = thread_r2r.retile_S(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s 
= make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_rD = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (!IsDirectR2S) { + return make_tensor(shape(tRS_sD(_,_,_,_0{}))); + } + else{ + return thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + } + }(); + + Tensor tRR_rD_dst_frg = recast>(coalesce(tRR_rD_dst)); + Tensor tRS_rD_frg = recast>(coalesce(tRS_rD)); + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_counting_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for sub-128 thread T2R tiled copy + Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; + constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); + bool issue_tmem_load = true; + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. 
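+    // In other words, whichever accumulator buffer is being drained, the overlapped epilogue tile
+    // column is always visited first (buffer 0 walks columns in reverse, buffer 1 in forward order),
+    // so the accumulator pipeline can be released after the first column iteration while the
+    // remaining, non-overlapped columns are still being read from tmem.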
+ [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + // Begin the wait for the accumulator results + ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); + + // For each epilogue 
subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gD_epi) - 1 - iter_n; + } + do_acc_release = iter_m == size<2>(gD_epi)-1 && iter_n == 0; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if (is_first_iteration) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); + } + + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + // Compute tmem load predication if necessary + if constexpr (predicate_tmem_load) { + // Issue tmem load if this tile's tmem subpartition is accessible by this warp + int subpart_idx = (tTR_tAcc_mn.data().dp_ / 32) % 4; + issue_tmem_load = warp_idx == subpart_idx; + } + bool issue_smem_store = issue_tmem_load; + + // Copy accumulator tile from tmem to register + if (issue_tmem_load) { + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + if constexpr (!IsDirectR2S) { + // At present, only FP4 col output with scalefactor generation fusion would go into these branch + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + tRS_rD_frg(_0{}) = cutlass::NumericArrayConverter{}(tRR_rD_dst_frg(_0{})); + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = 
make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + + // Copy output tile from register to smem + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + }; + + epi_loop_fn(cst_callbacks); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); + } + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + TensorStorage& shared_tensors, + TiledCopyT2R tiled_t2r + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
+ auto coord_shape = + conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = 
make_counting_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. 
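+    // Note: aside from the accumulators already being register-resident (no accumulator pipeline,
+    // no tmem loads or tmem-load predication), this path follows the same smem staging, TMA store,
+    // and fusion-callback sequence as the tmem-accumulator store() above.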
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + + // For each epilogue subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + 
cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + Tensor tTR_rAcc_epi_tile = tTR_rAcc(_,_,_,epi_m,epi_n); + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc_epi_tile)); // (EPI_V) + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rD_frg); + + // Copy output tile from register to smem + bool issue_smem_store = true; + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + template + CUTLASS_DEVICE void + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + CtaTileMNK cta_tile_mnk) { + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 4b18040a6e..be1ff675a3 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -51,9 +51,14 @@ struct 
EpiloguePtrArraySimtVectorized {}; struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecializedTransposed {}; -struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; + +struct TmaWarpSpecialized1Sm {}; +struct TmaWarpSpecialized2Sm {}; +struct PtrArrayTmaWarpSpecialized1Sm {}; +struct PtrArrayTmaWarpSpecialized2Sm {}; + struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; @@ -191,6 +196,46 @@ struct Sm90TmaWarpSpecializedBiasElementwise { constexpr static int FragmentSize = FragmentSize_; }; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm100TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm100PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + static_assert(StagesC >= 1, "StagesC must be >= 1"); + static_assert(StagesD >= 1, "StagesD must be >= 1"); +}; + +// default elementwise operator epilogue without smem +struct Sm100NoSmem {}; +struct Sm100NoSmemWarpSpecialized {}; +struct Sm100PtrArrayNoSmem {}; +struct Sm100PtrArrayNoSmemWarpSpecialized {}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/fusion/callbacks.hpp b/include/cutlass/epilogue/fusion/callbacks.hpp index c89db7f8fb..f9febeec4d 100644 --- a/include/cutlass/epilogue/fusion/callbacks.hpp +++ b/include/cutlass/epilogue/fusion/callbacks.hpp @@ -59,7 +59,8 @@ struct FusionCallbacks { template struct FusionCallbacksTraits { using DispatchPolicy = void; - using Operation = T; + using Callbacks = T; + using Operation = FusionOperation; using CtaTile_MNK = void; using EpilogueTile_MN = void; using ElementCompute = void; @@ -76,6 +77,7 @@ struct FusionCallbacksTraits< FusionCallbacks > { using DispatchPolicy = DispatchPolicy_; + using Callbacks = FusionCallbacks; using Operation = Operation_; using CtaTile_MNK = CtaTile_MNK_; using EpilogueTile_MN = EpilogueTile_MN_; diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index e1c53dac9d..156d358866 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -82,6 +82,12 @@ struct FusionOperation { using ElementAmax = void; static constexpr bool IsAbsMaxSupported = false; + + using ElementBlockScaleFactor = void; + static constexpr int SFVecSize = 0; + static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues + + using GmemLayoutTagScalefactor = void; }; // D = alpha * acc @@ -478,6 +484,140 @@ struct LinCombDeEltActDePerRowBias static constexpr bool IsDePerRowBiasSupported = true; }; + +template< + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class 
GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombBlockScaleFactor + : LinearCombination { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + +// D = activation(alpha * acc + beta * C) +// With BlockScaleFactor generation (same recipe as LinCombBlockScaleFactor). +template< + template class ActivationFn_, + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombEltActBlockScaleFactor + : LinCombEltAct { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + +// D = alpha * acc + beta * C + per-row bias +// With BlockScaleFactor generation +template< + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasBlockScaleFactor + : LinCombPerRowBias { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +// D = alpha * acc + beta * C + per-col bias +// With BlockScaleFactor generation. +template< + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasBlockScaleFactor + : LinCombPerColBias { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +// D = activation(alpha * acc + beta * C + per-row bias) +// With BlockScaleFactor generation. 
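+// As with the other *BlockScaleFactor operations above, one scale factor of type
+// ElementBlockScaleFactor_ is generated per SFVecSize_-element group of the output D and stored
+// in a separate scale-factor tensor laid out according to GmemLayoutTagScalefactor_.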
+template< + template class ActivationFn_, + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltActBlockScaleFactor + : LinCombPerRowBiasEltAct { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +// D = activation(alpha * acc + beta * C + per-col bias) +// With BlockScaleFactor generation. +template< + template class ActivationFn_, + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltActBlockScaleFactor + : LinCombPerColBiasEltAct { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp new file mode 100644 index 0000000000..24972141a2 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,955 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the sm100 TMA warp-specialized (ws) epilogue +*/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" + +#include "cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Sm100 Tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// Sm100 direct store callbacks alias to sm100 tma callbacks with 0 stages +// Additional copy atom args will be ignored in the 0-stage specializations of aux load/store nodes +template < + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100NoSmemWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// Sm100 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... 
+ > { + using FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +// With Row BlockScaleFactor Generation. +template< + int SFVecsize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinearCombRowBlockScaleFactor = + Sm90EVT, // gen scalefactor + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For Ptr-Array and Grouped GEMM +// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
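+// Illustrative host-side sketch only (not part of this patch): for a grouped GEMM using this
+// callback, the fusion arguments below are reached through the usual args.epilogue.thread path,
+// with hypothetical device buffers d_sf_ptrs (array of per-group scale-factor pointers) and
+// d_norm_const (single matrix-wide constant):
+//   args.epilogue.thread.alpha = 1.0f;
+//   args.epilogue.thread.block_scale_factor_ptr = d_sf_ptrs;    // ElementBlockScaleFactor**
+//   args.epilogue.thread.norm_constant_ptr      = d_norm_const; // ElementCompute const*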
+template< + int SFVecsize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinearCombRowBlockScaleFactorPtrArray = + Sm90EVT, // gen scalefactor + Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + // NormConst is a single device-side constant value, its not per-batch or per-group + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For Ptr-Array and Grouped GEMM +// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
+template< + int SFVecsize, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombEltActRowBlockScaleFactorPtrArray = + Sm90EVT, // gen scalefactor + Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized, + fusion::LinCombEltActBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltActBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-row bias +// with row blockScaled generation +template< + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerRowBiasRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerRowBiasRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, + ElementScalar, + AlignmentBias, + RoundStyle + > +{ + + using Impl = + Sm100LinCombPerRowBiasRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, 
AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per_col bias +// with row blockScaled generation +template< + int StagesC, + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerColBiasRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerColBiasRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm100LinCombPerColBiasRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename 
cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// with row blockScaled generation +template< + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerRowBiasEltActRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, + ElementOutput, ElementCompute, + ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + 
CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + + + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per_col bias) +// with row blockScaled generation +template< + int StagesC, + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerColBiasEltActRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, + ElementOutput, ElementCompute, + ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + 
ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerColBiasEltActRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm100LinCombPerColBiasEltActRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
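+    // (Concretely, Sm100BlockScaleFactorRowStore -- defined in
+    //  sm100_visitor_store_tma_warpspecialized.hpp -- scales each SFVecSize-wide output vector by
+    //  norm_constant / SF, with SF ~= amax(vector) * norm_constant / numeric_limits<ElementOutput>::max(),
+    //  so the quantized values occupy the full range of the narrow output type.)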
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp new file mode 100644 index 0000000000..a20591288a --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief Visitor tree compute operations for the sm100 TMA warp-specialized (ws) epilogue
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/detail/sm100_blockscaled_layout.hpp"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cute/tensor.hpp"
+#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
+#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
+#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::epilogue::fusion {
+
+using namespace cute;
+using namespace detail;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// BatchNormApply
+//
+// This node applies batch normalization in the epilogue. The computation is:
+//
+//   output = (input - mean) * inv_stddev * alpha + bias
+//
+// where: (1) input & output are 2 matrices with shape (M, N),
+//            which are the frg_input and the return value of the visit function,
+//
+//        (2) mean, inv_stddev, alpha & bias are 4 vectors with shape (N),
+// which are loaded by ProducerLoadCallbacks +// +// To avoid redundant calculations in EVT, this node simplify the procedure as follows: +// +// output = input * alpha' + bias' +// +// while alpha' & bias' are 2 vectors with shape (N) calculated by mean, inv_stddev, alpha & bias +// +// The calculation among vectors is described as follows: +// +// alpha' = alpha * inv_stddev +// bias' = bias - mean * alpha' +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least + // this should just match CLC stage count + int Stages, + class CtaTileShapeMNK, + class ElementScalar, + class ElementCompute, + class ElementOutput, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm100BatchNormApply { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(cute::is_same_v>); // row vector broadcast for alpha, bias, mean & inv_stddev + + using SmemLayout = decltype(make_layout(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})))); + + using ElementCol = cute::conditional_t<(sizeof(ElementCompute) > sizeof(ElementScalar)), ElementCompute, ElementScalar>; + + struct SharedStorage { + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_alpha; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_bias; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_mean; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_inv_stddev; + }; + + struct Arguments { + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* bias_ptr = nullptr; + ElementScalar const* mean_ptr = nullptr; + ElementScalar const* inv_stddev_ptr = nullptr; + StrideMNL dVec = {}; + }; + + struct Params { + using TMA_Vec = decltype(make_tma_atom( + SM90_TMA_LOAD{}, + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), + take<0,2>(SmemLayout{}), + take<0,2>(CtaTileShapeMNK{}))); + + TMA_Vec tma_load_alpha; + TMA_Vec tma_load_bias; + TMA_Vec tma_load_mean; + TMA_Vec tma_load_inv_stddev; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + Tensor tensor_alpha = make_tensor(make_gmem_ptr(args.alpha_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + Tensor tensor_bias = make_tensor(make_gmem_ptr(args.bias_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + Tensor tensor_mean = make_tensor(make_gmem_ptr(args.mean_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + Tensor tensor_inv_stddev = make_tensor(make_gmem_ptr(args.inv_stddev_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + + typename Params::TMA_Vec tma_load_alpha = make_tma_atom(SM90_TMA_LOAD{}, tensor_alpha, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + typename Params::TMA_Vec tma_load_bias = make_tma_atom(SM90_TMA_LOAD{}, tensor_bias, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + typename 
Params::TMA_Vec tma_load_mean = make_tma_atom(SM90_TMA_LOAD{}, tensor_mean, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + typename Params::TMA_Vec tma_load_inv_stddev = make_tma_atom(SM90_TMA_LOAD{}, tensor_inv_stddev, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + + return Params{tma_load_alpha, tma_load_bias, tma_load_mean, tma_load_inv_stddev}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm100BatchNormApply() { } + + CUTLASS_HOST_DEVICE + Sm100BatchNormApply(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_alpha(const_cast(shared_storage.smem_alpha.data())), + smem_bias(const_cast(shared_storage.smem_bias.data())), + smem_mean(const_cast(shared_storage.smem_mean.data())), + smem_inv_stddev(const_cast(shared_storage.smem_inv_stddev.data())), + smem_col_alpha(const_cast(shared_storage.smem_alpha.data())), + smem_col_bias(const_cast(shared_storage.smem_bias.data())) { } + + Params const* params_ptr; + ElementScalar* smem_alpha; + ElementScalar* smem_bias; + ElementScalar* smem_mean; + ElementScalar* smem_inv_stddev; + ElementCompute* smem_col_alpha; + ElementCompute* smem_col_bias; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& gAlpha, GTensor&& gBias, GTensor&& gMean, GTensor&& gInvStddev, + STensor&& sAlpha, STensor&& sBias, STensor&& sMean, STensor&& sInvStddev, Params const* params_ptr) + : gAlpha(cute::forward(gAlpha)), + gBias(cute::forward(gBias)), + gMean(cute::forward(gMean)), + gInvStddev(cute::forward(gInvStddev)), + sAlpha(cute::forward(sAlpha)), + sBias(cute::forward(sBias)), + sMean(cute::forward(sMean)), + sInvStddev(cute::forward(sInvStddev)), + params_ptr(params_ptr) {} + + GTensor gAlpha; + GTensor gBias; + GTensor gMean; + GTensor gInvStddev; + + STensor sAlpha; + STensor sBias; + STensor sMean; + STensor sInvStddev; + + Params const* params_ptr; + + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + if (epi_m == 0 && epi_n == 0 && issue_tma_load) { + // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size + constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * bits_to_bytes(sizeof_bits_v) * 4; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA bulk copy + int pipe_index = (load_iteration / EpiTiles) % Stages; + copy(params_ptr->tma_load_alpha.with(*full_mbarrier_ptr), gAlpha, sAlpha(_,pipe_index)); + copy(params_ptr->tma_load_bias.with(*full_mbarrier_ptr), gBias, sBias(_,pipe_index)); + copy(params_ptr->tma_load_mean.with(*full_mbarrier_ptr), gMean, sMean(_,pipe_index)); + copy(params_ptr->tma_load_inv_stddev.with(*full_mbarrier_ptr), gInvStddev, sInvStddev(_,pipe_index)); + } + } + }; + + template + CUTLASS_DEVICE auto + 
get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mAlpha = params_ptr->tma_load_alpha.get_tma_tensor(make_shape(size(M),N,size(L))); + Tensor mBias = params_ptr->tma_load_bias.get_tma_tensor(make_shape(size(M),N,size(L))); + Tensor mMean = params_ptr->tma_load_mean.get_tma_tensor(make_shape(size(M),N,size(L))); + Tensor mInvStddev = params_ptr->tma_load_inv_stddev.get_tma_tensor(make_shape(size(M),N,size(L))); + + Tensor gAlpha = local_tile(mAlpha, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gBias = local_tile(mBias, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gMean = local_tile(mMean, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gInvStddev = local_tile(mInvStddev, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + Tensor sBias = make_tensor(make_smem_ptr(smem_bias), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + Tensor sMean = make_tensor(make_smem_ptr(smem_mean), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + + auto [tCgAlpha, tCsAlpha] = tma_partition(params_ptr->tma_load_alpha, group_modes<0,2>(sAlpha), group_modes<0,2>(gAlpha)); + auto [tCgBias, tCsBias] = tma_partition(params_ptr->tma_load_bias, group_modes<0,2>(sBias), group_modes<0,2>(gBias)); + auto [tCgMean, tCsMean] = tma_partition(params_ptr->tma_load_mean, group_modes<0,2>(sMean), group_modes<0,2>(gMean)); + auto [tCgInvStddev, tCsInvStddev] = tma_partition(params_ptr->tma_load_inv_stddev, group_modes<0,2>(sInvStddev), group_modes<0,2>(gInvStddev)); + + constexpr int EpiTiles = decltype(size(ceil_div(shape(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ProducerLoadCallbacks( + cute::move(tCgAlpha), cute::move(tCgBias), cute::move(tCgMean), cute::move(tCgInvStddev), + cute::move(tCsAlpha), cute::move(tCsBias), cute::move(tCsMean), cute::move(tCsInvStddev), params_ptr); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + SR_RTensor&& tSR_rAlpha, SR_RTensor&& tSR_rBias, + SR_RTensor&& tSR_rMean, SR_RTensor&& tSR_rInvStddev, + SR_STensor&& tSR_sAlpha, SR_STensor&& tSR_sBias, + SR_STensor&& tSR_sMean, SR_STensor&& tSR_sInvStddev, + SR_CTensor&& tSR_cAlpha, + SR_SCTensor&& tSR_sColAlpha, SR_SCTensor&& tSR_sColBias, + RTensor&& tCrAlpha, RTensor&& tCrBias, + STensor&& tCsAlpha, STensor&& tCsBias, + ThrNum thr_num, + Params const* params_ptr) + : + tSR_rAlpha(cute::forward(tSR_rAlpha)), tSR_rBias(cute::forward(tSR_rBias)), + tSR_rMean(cute::forward(tSR_rMean)), tSR_rInvStddev(cute::forward(tSR_rInvStddev)), + tSR_sAlpha(cute::forward(tSR_sAlpha)), tSR_sBias(cute::forward(tSR_sBias)), + tSR_sMean(cute::forward(tSR_sMean)), tSR_sInvStddev(cute::forward(tSR_sInvStddev)), + tSR_cAlpha(cute::forward(tSR_cAlpha)), + tSR_sColAlpha(cute::forward(tSR_sColAlpha)), tSR_sColBias(cute::forward(tSR_sColBias)), + tCrAlpha(cute::forward(tCrAlpha)), tCrBias(cute::forward(tCrBias)), + tCsAlpha(cute::forward(tCsAlpha)), tCsBias(cute::forward(tCsBias)), + thr_num(thr_num), + params_ptr(params_ptr) {} + + SR_RTensor tSR_rAlpha; + SR_RTensor tSR_rBias; + SR_RTensor tSR_rMean; + SR_RTensor tSR_rInvStddev; + SR_STensor tSR_sAlpha; + SR_STensor tSR_sBias; + 
SR_STensor tSR_sMean; + SR_STensor tSR_sInvStddev; + SR_CTensor tSR_cAlpha; + SR_SCTensor tSR_sColAlpha; + SR_SCTensor tSR_sColBias; + + ThrNum thr_num; + + RTensor tCrAlpha; // (CPY,CPY_M,CPY_N) + RTensor tCrBias; // (CPY,CPY_M,CPY_N) + + STensor tCsAlpha; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + STensor tCsBias; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + + Params const* params_ptr; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if (epi_m == 0 && epi_n == 0) { // Assumes M-major subtile loop + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + int pipe_index = (load_iteration / EpiTiles) % Stages; + + Tensor tSR_rAlpha_flt = filter_zeros(tSR_rAlpha); + Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); + Tensor tSR_rMean_flt = filter_zeros(tSR_rMean); + Tensor tSR_rInvStddev_flt = filter_zeros(tSR_rInvStddev); + Tensor tSR_sAlpha_flt = filter_zeros(tSR_sAlpha(_,_,_,pipe_index)); + Tensor tSR_sBias_flt = filter_zeros(tSR_sBias(_,_,_,pipe_index)); + Tensor tSR_sMean_flt = filter_zeros(tSR_sMean(_,_,_,pipe_index)); + Tensor tSR_sInvStddev_flt = filter_zeros(tSR_sInvStddev(_,_,_,pipe_index)); + Tensor tSR_cAlpha_flt = filter_zeros(tSR_cAlpha, tSR_rAlpha.stride()); + + for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { + if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + // OOB of SMEM + continue; + } + tSR_rAlpha_flt(i) = tSR_sAlpha_flt(i); + tSR_rBias_flt(i) = tSR_sBias_flt(i); + tSR_rMean_flt(i) = tSR_sMean_flt(i); + tSR_rInvStddev_flt(i) = tSR_sInvStddev_flt(i); + } + + constexpr int RegFragSize = cute::min(size(tSR_rAlpha_flt), cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute)))); + Tensor tSR_rAlpha_frg = recast>(tSR_rAlpha_flt); // (FRG_V) + Tensor tSR_rBias_frg = recast>(tSR_rBias_flt); // (FRG_V) + Tensor tSR_rMean_frg = recast>(tSR_rMean_flt); // (FRG_V) + Tensor tSR_rInvStddev_frg = recast>(tSR_rInvStddev_flt); // (FRG_V) + + cutlass::multiplies> mul; + cutlass::negate> negate; + cutlass::multiply_add> mul_add; + + // We do computation among vectors before computation among matrices + // alpha' = alpha * inv_stddev + // bias' = bias - alpha' * mean + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAlpha_frg); ++i) { + tSR_rAlpha_frg(i) = mul(tSR_rAlpha_frg(i), tSR_rInvStddev_frg(i)); + tSR_rBias_frg(i) = mul_add(tSR_rAlpha_frg(i), negate(tSR_rMean_frg(i)), tSR_rBias_frg(i)); + } + + Tensor tSR_sColAlpha_flt = filter_zeros(tSR_sColAlpha(_,_,_,pipe_index)); + Tensor tSR_sColBias_flt = filter_zeros(tSR_sColBias(_,_,_,pipe_index)); + // After computation, 4 vectors -> 2 vectors + for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { + if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + // OOB of SMEM + continue; + } + tSR_sColAlpha_flt(i) = tSR_rAlpha_flt(i); + tSR_sColBias_flt(i) = tSR_rBias_flt(i); + } + + synchronize(); + + // To do bn_apply with Acc, reload these 2 vectors with the consistent shape + copy_aligned(tCsAlpha(_,_,_,_,_,pipe_index), tCrAlpha); + copy_aligned(tCsBias(_,_,_,_,_,pipe_index), tCrBias); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_inputs) { + constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); + 
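+      // Apply the folded parameters fragment-by-fragment: frg_apply = alpha' * frg_input + bias',
+      // reading alpha'/bias' from tCrAlpha/tCrBias at the matching (epi_m, epi_n) coordinates.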
cutlass::multiply_add> mul_add; + + Array frg_apply; + + using ConvertInput = NumericArrayConverter; + using ConvertOutput = NumericArrayConverter; + + ConvertInput convert_input{}; + ConvertOutput convert_output{}; + + Array frg_I = convert_input(frg_inputs); + + Tensor tCrAlpha_frg = recast>(tCrAlpha(_,_,_,epi_m,epi_n)); + Tensor tCrBias_frg = recast>(tCrBias(_,_,_,epi_m,epi_n)); + + constexpr int RegFragArraySize = FragmentSize / RegFragSize; + using RegFragArr = Array, RegFragArraySize>; + RegFragArr& frg_I_ = reinterpret_cast(frg_I); + RegFragArr& frg_apply_ = reinterpret_cast(frg_apply); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < RegFragArraySize; ++i) { + frg_apply_[i] = mul_add(tCrAlpha_frg(epi_v * RegFragArraySize + i), frg_I_[i], tCrBias_frg(epi_v * RegFragArraySize + i)); + } + + return convert_output(frg_apply); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sBias = make_tensor(make_smem_ptr(smem_bias), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sColAlpha = make_tensor(make_smem_ptr(smem_col_alpha), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sColBias = make_tensor(make_smem_ptr(smem_col_bias), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sMean = make_tensor(make_smem_ptr(smem_mean), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + + // S2R: Smem to Reg + auto tiled_s2r = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_s2r = tiled_s2r.get_slice(args.thread_idx); + Tensor tSR_sAlpha = thr_s2r.partition_S(sAlpha); + Tensor tSR_sBias = thr_s2r.partition_S(sBias); + Tensor tSR_sMean = thr_s2r.partition_S(sMean); + Tensor tSR_sInvStddev = thr_s2r.partition_S(sInvStddev); + Tensor tSR_sColAlpha = thr_s2r.partition_S(sColAlpha); + Tensor tSR_sColBias = thr_s2r.partition_S(sColBias); + Tensor tSR_cAlpha = thr_s2r.partition_S(args.cD); + + Tensor tSR_rAlpha = make_tensor_like(take<0,3>(tSR_sAlpha)); // need to check + Tensor tSR_rBias = make_tensor_like(take<0,3>(tSR_sBias)); + Tensor tSR_rMean = make_tensor_like(take<0,3>(tSR_sMean)); + Tensor tSR_rInvStddev = make_tensor_like(take<0,3>(tSR_sInvStddev)); + + Tensor tCsAlpha = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sColAlpha, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCsBias = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sColBias, args.epi_tile, args.tiled_copy, args.thread_idx); + + 
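+    // Note: sColAlpha/sColBias alias the smem_alpha/smem_bias buffers (see the constructor).
+    // previsit() overwrites them with the folded vectors alpha' = alpha * inv_stddev and
+    // bias' = bias - mean * alpha', then reloads them through tCsAlpha/tCsBias into
+    // tCrAlpha/tCrBias below with an accumulator-consistent partitioning for use in visit().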
Tensor tCrAlpha = make_tensor_like(take<0,5>(tCsAlpha)); // (CPY,CPY_M,CPY_N) + Tensor tCrBias = make_tensor_like(take<0,5>(tCsBias)); // (CPY,CPY_M,CPY_N) + + constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ConsumerStoreCallbacks( + cute::move(tSR_rAlpha), cute::move(tSR_rBias), + cute::move(tSR_rMean), cute::move(tSR_rInvStddev), + cute::move(tSR_sAlpha), cute::move(tSR_sBias), + cute::move(tSR_sMean), cute::move(tSR_sInvStddev), + cute::move(tSR_cAlpha), + cute::move(tSR_sColAlpha), cute::move(tSR_sColBias), + cute::move(tCrAlpha), cute::move(tCrBias), + cute::move(tCsAlpha), cute::move(tCsBias), + ThreadCount{}, + params_ptr); + } +}; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 0000000000..3c5b627261 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree store operations for the sm100 TMA warp-specialized (ws) epilogue +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/helper_macros.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +namespace detail { + template + CUTLASS_DEVICE auto + compute_quantized_with_row_scalefactor( + Array& frg_compute, + Array& frg_sf, + ElementCompute norm_constant) + { + cutlass::multiplies mul; + cutlass::multiplies> mul_array; + + Array frg_output; + auto output_frgs = reinterpret_cast *>(frg_output.data()); + auto compute_frgs = reinterpret_cast *>(frg_compute.data()); + + Array qpvscale_rcps = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (cute::is_same_v) { + // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate>{}(frg_sf); + return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcp); + } + else { + // UE4M3: Do the rcp in fp32 data type. + auto qpvscale_ups = cutlass::NumericArrayConverter{}(frg_sf); + return cutlass::reciprocal_approximate_ftz{}(qpvscale_ups); + } + }(); + + CUTLASS_PRAGMA_UNROLL + for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { + // norm_constant and qpvscale_rcps[sf_v] are all positive numbers. + ElementCompute acc_scale = mul(norm_constant, qpvscale_rcps[sf_v]); + // Map INF to fp32::max + acc_scale = minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Convert to output type + output_frgs[sf_v] = cutlass::NumericArrayConverter{}(mul_array(compute_frgs[sf_v], acc_scale)); + } + return frg_output; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// BlockScaleFactor Generation Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int SFVecSize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm100BlockScaleFactorRowStore { + static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); + static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or + size<1>(EpilogueTile{}) / SFVecSize == 2 or + size<1>(EpilogueTile{}) / SFVecSize == 4 or + size<1>(EpilogueTile{}) / SFVecSize == 8, + "Possible store in interleaved 4B aligned format"); + using NormalConstStrideMNL = Stride<_0,_0,int64_t>; + struct SharedStorage { }; + + struct Arguments { + ElementBlockScaleFactor* ptr_scale_factor = nullptr; + ElementCompute const* norm_constant_ptr = nullptr; + NormalConstStrideMNL norm_constant_stride = {}; + }; + + using Params = Arguments; + + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto problem_shape_MNKL = 
append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = (N % SFVecSize == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); + } + return implementable; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm100BlockScaleFactorRowStore() { } + + CUTLASS_HOST_DEVICE + Sm100BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class GTensor, + class CoordGTensor, + class ThrResidue, + class EpiTileCoordMN, + class ElementType + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) + GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + CoordGTensor tC_cSFD_, // (m,n) + ThrResidue residue_tC_cSFD_, // (m,n) + Params const* params_ptr_, + EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) + ElementType norm_constant_, + ElementType norm_constant_scaled_down_) + : tC_rSFD(cute::forward(tC_rSFD_)) + , tC_gSFD(cute::forward(tC_gSFD_)) + , tC_cSFD(tC_cSFD_) + , residue_tC_cSFD(residue_tC_cSFD_) + , params_ptr(params_ptr_) + , norm_constant(norm_constant_) + , norm_constant_scaled_down(norm_constant_scaled_down_) + , epi_tile_coord_mn(epi_tile_coord_mn_){} + + static_assert(is_same_v); + RTensor tC_rSFD; + GTensor tC_gSFD; + CoordGTensor tC_cSFD; + ThrResidue residue_tC_cSFD; + Params const* params_ptr; + ElementCompute norm_constant; + ElementCompute norm_constant_scaled_down; + EpiTileCoordMN epi_tile_coord_mn; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, + int epi_v, + int epi_m, + int epi_n, + Array const& frg_input) + { + static_assert(FragmentSize % SFVecSize == 0, "Scale factor vector size should divide FragmentSize"); + constexpr int NumVecs = FragmentSize / SFVecSize; + Array frg_compute; + + auto input_frgs = reinterpret_cast const*>(frg_input.data()); + auto compute_frgs = reinterpret_cast *>(frg_compute.data()); + + Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) + + cutlass::multiplies mul; + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + + cutlass::Array pvscales; + // SF generation + CUTLASS_PRAGMA_UNROLL + for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { + compute_frgs[sf_v] = NumericArrayConverter{}(input_frgs[sf_v]); + /// Step1: get max across a vector + ElementCompute vec_max = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); + /// Step2: Compute Scale + pvscales[sf_v] = mul(vec_max, norm_constant_scaled_down); + } + + tC_rSFD_frg(_0{}) = cutlass::NumericArrayConverter{}(pvscales); + + Tensor tCgSFD_flt = filter_zeros(tC_gSFD(_,_,_,_0{},_0{},get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) 
+ epi_n)); + Tensor tCrSFD_flt = filter_zeros(tC_rSFD); + constexpr auto MCL = decltype(max_common_layout(tCgSFD_flt, tCrSFD_flt)){}; + constexpr int V = cute::min(4, size(MCL)); + using VecType = uint_bit_t>; + Tensor tCgSFD_vec = recast(coalesce(tCgSFD_flt)); + Tensor tCrSFD_vec = recast(coalesce(tCrSFD_flt)); + Tensor tCcSFD_pred = tC_cSFD(_,_,_, epi_m, epi_n); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSFD_vec); i++){ + if (elem_less(tCcSFD_pred(i * SFVecSize * V), residue_tC_cSFD)) { + tCgSFD_vec(i) = tCrSFD_vec(i); + } + } + /// Step3: Compute quantized output values + return detail::compute_quantized_with_row_scalefactor(frg_compute, tC_rSFD_frg(_0{}), norm_constant); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; + using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; + tile_coord_l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } + + auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); + Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) + Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + + auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); + + // Fetch and compute these during initialization + Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); + ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); + ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); + ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); + ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); +#if 0 + if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ + print("epi_tile ");print(args.epi_tile); print("\n"); + print("mSFD ");print(mSFD); print("\n"); + print("gSFD ");print(gSFD); print("\n"); + print("tCgSFD ");print(tCgSFD); print("\n"); + print("tCrSFD ");print(tCrSFD); print("\n"); + print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); + print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); + } +#endif + + return ConsumerStoreCallbacks( + cute::move(tCrSFD), + cute::move(tCgSFD), + args.tCcD, + args.residue_tCcD, + params_ptr, + epi_tile_coord_mn, + norm_constant, + norm_constant_scaled_down); + + } +}; + +} // namespace 
cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/exmy_base.h b/include/cutlass/exmy_base.h new file mode 100644 index 0000000000..5c4e54603c --- /dev/null +++ b/include/cutlass/exmy_base.h @@ -0,0 +1,1219 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! + \file + \brief Generic floating-point type for ExMy format +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_size.h" +#include "cutlass/platform/platform.h" + +// #define CUTLASS_DEBUG_TRACE_LEVEL 2 +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + // Helper functions +namespace detail { + +template +CUTLASS_HOST_DEVICE +Dst copy_bits(Src src) +{ + Dst dst; + static_assert(sizeof(Src) <= sizeof(Dst), "Dst type should be at least the same size as Src type"); + static_assert(cutlass::platform::is_trivially_copyable::value, "Dst type should be trivially copyable"); + static_assert(cutlass::platform::is_trivially_copyable< + /*cutlass::platform::remove_cvref_t< */ Dst /* > */ >::value, "Dst type should be trivially copyable"); + memcpy(&dst, &src, sizeof(src)); + return dst; +} + +enum class NanInfEncoding +{ + // IEEE-754 style NaN. Exponent bits are + // all ones, and at least one bit of mantissa is one + IEEE_754, + // Canonical NaN. There is only one value representing NaN and + // no Inf is defined. + CANONICAL_ONLY, + // No NaN or Inf encoded. 
+ NONE +}; + +enum class FpEncoding +{ + E11M52, // double + E8M23, // float + E5M2, // FP8 + E4M3, // FP8 + UE4M3, // FP8 + UE8M0, // FP8 + E3M2, // FP6 + E2M3, // FP6 + E2M1, // FP4 +}; + +////// + +#if (CUTLASS_CXX17_OR_LATER) +template +CUTLASS_CONSTEXPR_IF_CXX17 int exponent_bias_cxx17() { + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + static_assert(NumMantissaBits <= static_cast(cutlass::platform::numeric_limits::max())); + return -1 * static_cast(NumMantissaBits); + } + else { + return static_cast((1 << (NumExpBits - 1))) - 1; + } +} +#endif + +namespace impl { +template +constexpr int shift_num_bits_expression_cxx11() { +#if (__cplusplus >= 201700L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201700L)) + static_assert(NumExpBitsMinusOne <= 31u); +#endif + return NumExpBitsMinusOne > 31u ? 31u : NumExpBitsMinusOne; +} + +template +constexpr int inner_shift_expression_cxx11() { + return static_cast((1u << shift_num_bits_expression_cxx11()) - 1u); +} + +} // namespace impl + +// C++11 equivalent of exponent_bias_cxx17() +template +constexpr int exponent_bias_cxx11() { +#if (__cplusplus >= 201700L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201700L)) + return exponent_bias_cxx17(); +#else + return (NumExpBits == 0) ? + -1 * static_cast(NumMantissaBits) : impl::inner_shift_expression_cxx11(); +#endif +} + +// C++11 equivalent of maximum_exponent_cxx17() +template +constexpr int maximum_exponent_cxx11() { + return + ((NumExpBits == 0) ? + (0 - exponent_bias_cxx11()) : + ((NaNEncoding == NanInfEncoding::IEEE_754) ? + ((static_cast((1 << NumExpBits)) - 2) - exponent_bias_cxx11()) : + ((NaNEncoding == NanInfEncoding::CANONICAL_ONLY) ? + ((NumMantissaBits > 0) ? + static_cast((1 << NumExpBits)) - 1 - exponent_bias_cxx11() : + static_cast((1 << NumExpBits)) - 2 - exponent_bias_cxx11() + ) : + (static_cast((1 << NumExpBits)) - 1 - exponent_bias_cxx11()) + ) + ) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr int maximum_exponent_cxx17() { + constexpr int exp_bias = exponent_bias_cxx17(); + if constexpr (NumExpBits == 0) { + // If no exponent bits, return fixed hidden bias + return 0 - exp_bias; + } + else { + if constexpr (NaNEncoding == NanInfEncoding::IEEE_754) { + // We have IEEE style NaN and infinity + // All values when exp_bits = 1...1s are used. + int max_exp_bits = static_cast((1 << NumExpBits)) - 2; + return max_exp_bits - exp_bias; + } + else { + // There are no cases where we have Inf without IEEE_754_Nan + + // If we have a canonical NaN. Only exp=1..1 and mantissa=1..1 + // value has a special meaning. If we also have at least one mantissa + // bit, then maximum exponent is 1...1 - exponent_bias + if constexpr (NaNEncoding == NanInfEncoding::CANONICAL_ONLY) { + if constexpr (NumMantissaBits > 0) { + int max_exp_bits = static_cast((1 << NumExpBits)) - 1; + return max_exp_bits - exp_bias; + } + else { // no mantissa bits + int max_exp_bits = static_cast((1 << NumExpBits)) - 2; + return max_exp_bits - exp_bias; + } + } + // No NaNs or infs + int max_exp_bits = static_cast((1 << NumExpBits)) - 1; + return max_exp_bits - exp_bias; + } + } +} +#endif + +template +constexpr int minimum_exponent_cxx11() { + return + ((NumExpBits == 0) ? + 0 - exponent_bias_cxx11() : + ((NumMantissaBits > 0) ? 
+ 1 - exponent_bias_cxx11() : + 0 - exponent_bias_cxx11()) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr int minimum_exponent_cxx17() { + constexpr int exp_bias = exponent_bias_cxx17(); + constexpr bool has_denorm = (NumMantissaBits > 0); + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + // If no exponent bits, return fixed hidden bias + // Note that minimum and maximum exponents are the same. + return 0 - exp_bias; + } + + if CUTLASS_CONSTEXPR_IF_CXX17 (has_denorm) { + // Exp = 0...0s is reserved for denorm values. + return 1 - exp_bias; + } + return 0 - exp_bias; +} +#endif + +template +constexpr Storage max_pos_denormal_value_cxx11() { + static_assert(NumExpBits > 0 || NumMantissaBits > 0, "Both NumExpBits and NumMantissaBits can't be zero"); + return + (!(NumMantissaBits > 0) ? Storage(0) : Storage((1ull << NumMantissaBits) - 1)); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage max_pos_denormal_value_cxx17() { + static_assert(NumExpBits > 0 || NumMantissaBits > 0, "Both NumExpBits and NumMantissaBits can't be zero"); + constexpr bool has_denorm = (NumMantissaBits > 0); + if constexpr (!has_denorm) { + // If we don't have denormal values, return all 0s + return Storage(0); + } + else { + // Case: (NumExpBits > 0 && NumMantissaBits > 0) or (NumExpBits == 0 && NumMantissaBits > 0) + return Storage((1ull << NumMantissaBits) - 1); + } +} +#endif + + +template +constexpr Storage min_pos_denormal_value_cxx11() { + return (!(NumMantissaBits > 0) ? Storage(0) : Storage(1)); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage min_pos_denormal_value_cxx17() { + constexpr bool has_denorm = (NumMantissaBits > 0); + if constexpr (!has_denorm) { + // If we don't have denormal values, return all 0s + return Storage(0); + } + // Case: (NumExpBits > 0 && NumMantissaBits > 0) or (NumExpBits == 0 && NumMantissaBits > 0) + return Storage(1); +} +#endif + +template +constexpr Storage max_pos_normal_value_cxx11() { + return + ((NumExpBits == 0) ? + Storage(0) : + ((NumMantissaBits == 0) ? + 0 : + (((NaNEncoding == NanInfEncoding::IEEE_754 || NaNEncoding == NanInfEncoding::NONE) ? + ((1ull << NumMantissaBits) - 1) : + ((1ull << NumMantissaBits) - 2))) + ) | (static_cast( + maximum_exponent_cxx11() + + exponent_bias_cxx11() + ) << NumMantissaBits) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage max_pos_normal_value_cxx17() { + if constexpr (NumExpBits == 0) { + // if there are no exponent bits, we don't have normal values. + return Storage(0); + } + constexpr int exp_bias = exponent_bias_cxx17(); + constexpr int max_exp = maximum_exponent_cxx17(); + constexpr int exp = max_exp + exp_bias; + + // place the exponent + Storage val = static_cast(exp) << NumMantissaBits; + // If there are no mantissa bits return the exponent + if constexpr (NumMantissaBits == 0) { + return val; + } + else { + // If the NaN Inf encoding follows IEEE 754 or there is no (NaN and Inf) then mantissa can be all 1..1s + if constexpr (NaNEncoding == NanInfEncoding::IEEE_754 || + NaNEncoding == NanInfEncoding::NONE ) { + Storage mantissa = (1ull << NumMantissaBits) - 1; + val |= mantissa; + } + else { + // If we have a canonical NaN, then the exponent can be the maximum bit value + // but mantissa=1..1s is reserved for NaN. + Storage mantissa = (1ull << NumMantissaBits) - 2; + val |= mantissa; + } + return val; + } +} +#endif + +template +constexpr Storage min_pos_normal_value_cxx11() { + return + ((NumExpBits == 0) ? + Storage(0) : + (Storage((NumMantissaBits > 0) ? 
1 : 0) << NumMantissaBits) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage min_pos_normal_value_cxx17() { + constexpr bool has_denorm = (NumMantissaBits > 0); + + if constexpr (NumExpBits == 0) { + // if there are no exponent bits, we don't have normal values. + return Storage(0); + } + Storage exp = 0; + if constexpr (has_denorm) { + exp = 1; + } + return static_cast(exp << NumMantissaBits); +} +#endif + +template +constexpr Storage max_value_cxx11() { + return + ((NumExpBits > 0) ? + max_pos_normal_value_cxx11() : + max_pos_denormal_value_cxx11() + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage max_value_cxx17() { + constexpr bool has_normal = (NumExpBits > 0); + if (has_normal) { + return max_pos_normal_value_cxx17(); + } + else { + return max_pos_denormal_value_cxx17(); + } +} +#endif + +template +constexpr Storage min_value_cxx11() { + return + (IsSigned ? + Storage(1ull << (NumExpBits + NumMantissaBits)) | max_value_cxx11() : + Storage(0) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage min_value_cxx17() { + if (IsSigned) { + return Storage(1ull << (NumExpBits + NumMantissaBits)) | max_value_cxx17(); + } + else { // Unsigned number + return Storage(0); + } +} +#endif + +template < + class StorageType, + uint32_t NumBits, uint32_t NumExpBits, uint32_t NumMantissaBits, + NanInfEncoding Nan = NanInfEncoding::IEEE_754, bool IsSigned = true> +struct FpBitRepresentation { +public: + + using Storage = StorageType; + +#if (201700L <= __cplusplus) + static_assert(cutlass::platform::is_unsigned_v, "Use an unsigned integer for StorageType"); +#endif + static constexpr bool IS_SIGNED = IsSigned; + // Canonical NaN is always represented as exponent=11...11 and mantissa=11...11, if it exists + static constexpr NanInfEncoding NAN_TYPE = Nan; + // Inf is always represented as exponent=11...11 and mantissa=00...00, if it exists + static constexpr bool HAS_INF = (NAN_TYPE == NanInfEncoding::IEEE_754); + static constexpr bool HAS_NAN = (NAN_TYPE != NanInfEncoding::NONE); + + static constexpr bool HAS_DENORM = (NumMantissaBits > 0); + static constexpr bool HAS_NORMAL = !HAS_DENORM; + + static constexpr uint32_t NUM_BITS = NumBits; + static constexpr uint32_t NUM_EXPONENT_BITS = NumExpBits; + static constexpr uint32_t NUM_MANTISSA_BITS = NumMantissaBits; + static_assert(NUM_BITS >= (NUM_EXPONENT_BITS + NUM_MANTISSA_BITS + uint32_t(IS_SIGNED)), "Number of bits do not match"); + + static constexpr Storage ONE = Storage(1); + static constexpr Storage ZERO = Storage(0); + + // Note: Don't rely on operator precedence. Use parenthesis. + static constexpr Storage EXPONENT_MASK = (Storage(1) << Storage(NUM_EXPONENT_BITS)) - ONE; + static constexpr Storage MANTISSA_MASK = (Storage(1) << Storage(NUM_MANTISSA_BITS)) - ONE; + static constexpr Storage EXPONENT_SHIFT = Storage(NUM_MANTISSA_BITS); + static constexpr Storage SIGN_SHIFT = (IS_SIGNED) ? Storage(NUM_MANTISSA_BITS + NUM_EXPONENT_BITS) : Storage(0); + + // Note: All biased/real exponent calculation are done with signed ints + // Use unsigned to represent data not exponent. 
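  // Illustrative values (editor's sketch, not part of the original header): assuming a hypothetical
  // E4M3-style instantiation such as FpBitRepresentation<uint8_t, 8, 4, 3, NanInfEncoding::CANONICAL_ONLY>,
  // the constants below evaluate to
  //   EXP_BIAS = 7, MAX_EXP = 8, MIN_EXP = -6,
  //   MAX_POS_NORMAL_VAL   = 0x7E (1.75 * 2^8 = 448),   MIN_POS_NORMAL_VAL   = 0x08 (2^-6),
  //   MAX_POS_DENORMAL_VAL = 0x07 (0.875 * 2^-6),       MIN_POS_DENORMAL_VAL = 0x01 (2^-9),
  //   MAX_VALUE = 0x7E (+448), MIN_VALUE = 0xFE (-448).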
+ static constexpr int EXP_BIAS = detail::exponent_bias_cxx11(); + static constexpr int MAX_EXP = detail::maximum_exponent_cxx11(); + static constexpr int MIN_EXP = detail::minimum_exponent_cxx11(); + + // Floating-point Limits + static constexpr Storage MAX_POS_NORMAL_VAL = detail::max_pos_normal_value_cxx11(); + static constexpr Storage MAX_POS_DENORMAL_VAL = detail::max_pos_denormal_value_cxx11(); + static constexpr Storage MIN_POS_NORMAL_VAL = detail::min_pos_normal_value_cxx11(); + static constexpr Storage MIN_POS_DENORMAL_VAL = detail::min_pos_denormal_value_cxx11(); + + static constexpr Storage MAX_VALUE = max_value_cxx11(); + static constexpr Storage MIN_VALUE = min_value_cxx11(); + + // + // C++17 Verification + // +#if (CUTLASS_CXX17_OR_LATER) + static_assert(EXP_BIAS == detail::exponent_bias_cxx17(), "Error"); + static_assert(MAX_EXP == detail::maximum_exponent_cxx17(), "Error"); + static_assert(MIN_EXP == detail::minimum_exponent_cxx17(), "Error"); + + static_assert(MAX_POS_NORMAL_VAL == detail::max_pos_normal_value_cxx17(), "Error"); + static_assert(MAX_POS_DENORMAL_VAL == detail::max_pos_denormal_value_cxx17(), "Error"); + static_assert(MIN_POS_NORMAL_VAL == detail::min_pos_normal_value_cxx17(), "Error"); + static_assert(MIN_POS_DENORMAL_VAL == detail::min_pos_denormal_value_cxx17(), "Error"); + static_assert(MAX_VALUE == max_value_cxx17(), "Error"); + static_assert(MIN_VALUE == min_value_cxx17(), "Error"); +#endif + + // If we don't have INF defined, set the largest number. Gives us .satfinite behavior. + static constexpr Storage INF_MASK = (HAS_INF) ? + (Storage(EXPONENT_MASK) << Storage(NUM_MANTISSA_BITS)) : MAX_VALUE; + static constexpr Storage NAN_MASK = (Storage(EXPONENT_MASK) << Storage(NUM_MANTISSA_BITS)) | MANTISSA_MASK; + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_inf(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!HAS_INF) { + return false; + } + bool exp_all_ones = (exponent_bits(flt) ^ EXPONENT_MASK) == 0; + bool mantissa_all_zeros = mantissa_bits(flt) == 0; + return exp_all_ones && mantissa_all_zeros; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_canonical_nan(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NAN_TYPE == NanInfEncoding::NONE) { + return false; + } + bool exp_all_ones = (exponent_bits(flt) ^ EXPONENT_MASK) == ZERO; + bool mantissa_all_ones = (mantissa_bits(flt) ^ MANTISSA_MASK) == ZERO; + return exp_all_ones && mantissa_all_ones; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_nan(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NAN_TYPE == NanInfEncoding::NONE) { + return false; + } + + if CUTLASS_CONSTEXPR_IF_CXX17 (NAN_TYPE == NanInfEncoding::CANONICAL_ONLY) { + return is_canonical_nan(flt); + } + + bool exp_all_ones = (exponent_bits(flt) ^ EXPONENT_MASK) == ZERO; + bool mantissa_has_ones = mantissa_bits(flt) > ZERO; + return exp_all_ones && mantissa_has_ones; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_denorm(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!HAS_DENORM) { + return false; + } + else if (exponent_bits(flt) == ZERO) { + // Exponent bits are all 0s + return true; + } + return false; + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T sign_bit(T flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!IS_SIGNED) { + return T(0); + } + return static_cast(flt >> T(SIGN_SHIFT)); + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T set_sign_bit(T flt, T sign) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!IS_SIGNED) 
{ + return flt; + } + return static_cast(flt | (sign << T(SIGN_SHIFT))); + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage exponent_bits(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_EXPONENT_BITS == ZERO) { + return ZERO; + } + return (flt >> (NUM_MANTISSA_BITS)) & EXPONENT_MASK; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 int exponent(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_EXPONENT_BITS == ZERO) { + return -int(EXP_BIAS); + } + + if (HAS_DENORM && (exponent_bits(flt) == ZERO)) { + return 1 - int(EXP_BIAS); + } + + return int(flt >> (NUM_MANTISSA_BITS) & EXPONENT_MASK) - int(EXP_BIAS); + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage mantissa_bits(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_MANTISSA_BITS == ZERO) { + return ZERO; + } + return (flt & MANTISSA_MASK); + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage to_bits(FpType flt) { + return copy_bits(flt); + } + + template + CUTLASS_HOST_DEVICE static typename DstFpBits::Storage convert_to( + Storage src_val, + DstFpBits dst_encoding) { + return convert(FpBitRepresentation{}, src_val, dst_encoding); + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage convert_from( + typename SrcFpBits::Storage src_val, + SrcFpBits src_encoding) { + return convert(src_encoding, src_val, FpBitRepresentation{}); + } + +private: + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T make_fp_from_bits(T sign, T exp, T mantissa) { + T fp_bits = T(ZERO); + CUTLASS_UNUSED(sign); + if CUTLASS_CONSTEXPR_IF_CXX17 (IS_SIGNED) { + fp_bits = sign << SIGN_SHIFT; + } + fp_bits |= (exp << T(NUM_MANTISSA_BITS)); + fp_bits |= (mantissa); + return fp_bits; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage nan_with_sign(Storage sign) { + Storage fp_bits = NAN_MASK; + return set_sign_bit(fp_bits, sign); + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage inf_with_sign(Storage sign) { + if CUTLASS_CONSTEXPR_IF_CXX17 (HAS_INF) { + Storage fp_bits = INF_MASK; + return set_sign_bit(fp_bits, sign); + } + else { + // If INF is not defined assume satfinite behavior + return (sign == ZERO) ? MAX_VALUE : MIN_VALUE; + } + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage significand(Storage flt) { + if (is_denorm(flt)) { + return mantissa_bits(flt); + } + else { + return (ONE << Storage(NUM_MANTISSA_BITS)) | mantissa_bits(flt); + } + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T significand_hidden_bits(T significand) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_MANTISSA_BITS == 0) { + return T(1); + } + return ((T(0b11) << T(NUM_MANTISSA_BITS)) & significand) >> T(NUM_MANTISSA_BITS); + } + + // Current assumption round to nearest even + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T round_significand(T src, int shift_amount) { + T dst_mantissa = src; + // If the shift amount is positive, we are shifting left + // Type with less mantissa bits is rounded to a type with more + // mantissa bits. + if (shift_amount > 0) { + dst_mantissa = (dst_mantissa << (shift_amount)); + } + else { + // There are fewer mantissa bits in the target type + // we need to round the destination number up for all + // lower precision bits removed. + // We assume round-to-nearest-even here. + int pos_shift_amount = -shift_amount; + + // Too large shift return all zeros to prevent undefined behavior for shift. 
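          // Worked example (editor's illustration): with pos_shift_amount = 3 and src = 0b1101'100,
          // guard = bit 3 (1), round = bit 2 (1), sticky = OR of bits 1..0 (0): a tie with an odd kept
          // LSB, so the shifted result 0b1101 is rounded up to 0b1110. With src = 0b1100'100 the kept
          // LSB is even (guard = 0), so the tie rounds down and the result stays 0b1100.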
+ if (pos_shift_amount >= static_cast(sizeof(T) * 8)) { + return T(0); + } + + T guard_bit_mask = (T(1) << T(pos_shift_amount)); // Last bit to remain in mantissa + T sticky_mask = (T(1) << T(pos_shift_amount - 1)) - T(1); // Remaining bits + T round_bit_mask = (T(1) << T(pos_shift_amount - 1)); // First bit removed from mantissa + + bool sticky_bit = (src & sticky_mask) >= T(1); // ORing all sticky bits + bool round_bit = (src & round_bit_mask) >= T(1); + bool guard_bit = (src & guard_bit_mask) >= T(1); + + // Shift mantissa bits to right to remove lowest precision bits + dst_mantissa = dst_mantissa >> pos_shift_amount; + + if ((sticky_bit && round_bit) || (guard_bit && round_bit && !sticky_bit)) { + dst_mantissa += 1; + } + } + return dst_mantissa; + } + + template + CUTLASS_HOST_DEVICE + static typename DstFpBits::Storage convert( + SrcFpBits src_encoding, + typename SrcFpBits::Storage src_val, + DstFpBits dst_encoding) { + + using SrcT = typename SrcFpBits::Storage; + using DstT = typename DstFpBits::Storage; + using LargeStorage = typename cutlass::platform::conditional<(sizeof(SrcT) > sizeof(DstT)), SrcT, DstT>::type; + + + LargeStorage src_sign_bit = src_encoding.sign_bit(src_val); + + // If the source is NaN, set the destination to NaN carrying the sign bit + if (src_encoding.is_nan(src_val)) { + return dst_encoding.nan_with_sign(DstT(src_sign_bit)); + } + // If the source is INF, set the destination to INF carrying the sign bit + else if (src_encoding.is_inf(src_val)) { + return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + // Number is not NaN or INF: Zero and others + + LargeStorage src_exp_bits = src_encoding.exponent_bits(src_val); + LargeStorage src_significand = src_encoding.significand(src_val); + int src_exp = src_encoding.exponent(src_val); + + // The source value is 0. Return a signed 0. + if (src_exp_bits == LargeStorage(0) && src_significand == LargeStorage(0)) { + return dst_encoding.set_sign_bit(DstT(0), DstT(src_sign_bit)); + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(1) src_sign: %llu src_exp_bits %llx src_exp %d src_significand %llx\n", + static_cast(src_sign_bit), static_cast(src_exp_bits), src_exp, static_cast(src_significand)); +#endif + // Normalize the number: Left shift the significand bits until hidden "1" appears. + // Only needed if the src value is denormal. + // Conditions: + // If the exponent is 0, then the significand can't be 0 (src_val==0 case handled above): + // there is at least one "1" bit in the significand. Loop executes. + // If the exponent is not 0, then the number is normal: + // significand has hidden bit set. Loop doesn't execute. + // Assumption: Zero is always defined for the floating point types and detected above + + while (src_encoding.significand_hidden_bits(src_significand) == LargeStorage(0)) { + src_significand <<= LargeStorage(1); + src_exp--; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(2) src_sign: %llu src_exp_bits %llx src_exp %d src_significand %llx\n", + static_cast(src_sign_bit), static_cast(src_exp_bits), src_exp, static_cast(src_significand)); +#endif + // The exponent exceeds DstFormat's exponent capacity + // Return positive/negative infinity. + // If no INF is defined, return positive/negative largest value. 
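    // Editor's illustration: converting float(1.0e6f), whose unbiased exponent is 19, to an E4M3-style
    // destination (MAX_EXP = 8) takes the branch just below; because that format defines no Inf,
    // INF_MASK aliases MAX_VALUE, so the result saturates to +448 (.satfinite behavior).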
+ if (src_exp > DstFpBits::MAX_EXP) { + + return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + else if (src_exp <= DstFpBits::MAX_EXP && src_exp >= DstFpBits::MIN_EXP) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(3) Exp match: src_sign: %d src_exp_bits: %x src_exp: %d src_significand: %x\n", + src_sign_bit, src_exp_bits, src_exp, src_significand); +#endif + + int shift_amount = int(DstFpBits::NUM_MANTISSA_BITS) - int(SrcFpBits::NUM_MANTISSA_BITS); + int dst_exponent = src_exp + DstFpBits::EXP_BIAS; + LargeStorage dst_mantissa = src_significand; + + // if we have an M0 case, the floating point number is always denormal. + // Therefore, if exponents are equal, we need to check whether it is inf + if (DstFpBits::NUM_EXPONENT_BITS == 0) { + if (dst_mantissa > DstFpBits::INF_MASK) { + return dst_encoding.inf_with_sign(DstT(src_sign_bit)); + } + } + + // Round to nearest even + dst_mantissa = round_significand(dst_mantissa, shift_amount); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(4) after rounding src_sign: %d dst_exponent: %d dst_mantissa: %x\n", + src_sign_bit, dst_exponent, dst_mantissa); +#endif + + // TODO potential narrowing here + if (dst_encoding.significand_hidden_bits(dst_mantissa) > 0b1) { + + // Significant became larger than 01.X...X. Divide significand by 2 and multiply exp by 2 + while (dst_exponent < (DstFpBits::MAX_EXP+DstFpBits::EXP_BIAS) && + dst_encoding.significand_hidden_bits(dst_mantissa) > LargeStorage(0b1)) { + dst_mantissa >>= LargeStorage(1); + dst_exponent++; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(5) after rounding max_exp: %d src_sign: %d dst_exponent: %d dst_mantissa: %x\n", + DstFpBits::MAX_EXP,src_sign_bit, dst_exponent, dst_mantissa); +#endif + + if (dst_encoding.significand_hidden_bits(dst_mantissa) > LargeStorage(0b1)) { + return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + } + + dst_mantissa = dst_mantissa & DstFpBits::MANTISSA_MASK; + static_assert(sizeof(LargeStorage) >= sizeof(decltype(dst_exponent)), + "sizeof(LargeStorage) must be greater than or equal to sizeof(decltype(dst_exponent))"); + LargeStorage dst_exponent_bits = static_cast(dst_exponent); + + DstT final_val = static_cast(dst_encoding.template make_fp_from_bits(src_sign_bit, dst_exponent_bits, dst_mantissa)); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(6) Final Value src_sign: %d dst_exp_bits: %x dst_mantissa: %x\n", + src_sign_bit, dst_exponent_bits, dst_mantissa); +#endif + + if (DstFpBits::is_nan(final_val)) { + // This NAN is generated when: + // Src is not an Nan + // the exp of Src == the max_exp of Dst. + // The mantissa becomes all-1s after rounding. + // Return max value of Dst (not NAN) as it just couldn't be represented in the range of Dst. 
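      // Editor's illustration: converting float(480.0f) to an E4M3-style destination reaches this point
      // with exponent 8 (== MAX_EXP) and significand 1.111b, which survives rounding intact, so the
      // assembled bits are 0.1111.111, i.e. the canonical NaN pattern. The is_nan(final_val) check above
      // fires, and the return below substitutes MAX_VALUE (+448, aliased by INF_MASK when no Inf exists)
      // carrying the source sign.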
+ return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + else { + return final_val; + } + } + else { + // Result is denormal +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(7) Denormal case src_sign: %d src_exp: %d src_significand: %x MIN_EXP: %d\n", + src_sign_bit, src_exp, src_significand, DstFpBits::MIN_EXP); +#endif + + int exp_diff = src_exp - DstFpBits::MIN_EXP; + int shift_amount = int(DstFpBits::NUM_MANTISSA_BITS) - int(SrcFpBits::NUM_MANTISSA_BITS); + shift_amount += exp_diff; + LargeStorage dst_mantissa = src_significand; + dst_mantissa = round_significand(dst_mantissa, shift_amount); + + if (dst_encoding.significand_hidden_bits(dst_mantissa) >= LargeStorage(0b1)) { + if CUTLASS_CONSTEXPR_IF_CXX17 (DstFpBits::NUM_EXPONENT_BITS == 0) { + return dst_encoding.inf_with_sign(DstT(src_sign_bit)); + } + else { + LargeStorage dst_exp_bits = 1; + dst_mantissa &= DstFpBits::MANTISSA_MASK; + DstT final_val = static_cast(dst_encoding.template make_fp_from_bits(src_sign_bit, dst_exp_bits, dst_mantissa)); + return final_val; + } + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(7.1) Denormal case exp_diff: %d shift_amount: %d dst_mantissa %d\n", exp_diff, shift_amount, dst_mantissa); +#endif + dst_mantissa &= DstFpBits::MANTISSA_MASK; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(8) Final Value src_sign: %d src_exp: %d dst_mantissa: %x\n", + src_sign_bit, src_exp, dst_mantissa); +#endif + + DstT final_val = static_cast(dst_encoding.template make_fp_from_bits(src_sign_bit, LargeStorage(0), dst_mantissa)); + return final_val; + } + + return DstT(0); + } + + template + friend struct FpBitRepresentation; +}; + +#if (CUTLASS_CXX17_OR_LATER) + +template +CUTLASS_CONSTEXPR_IF_CXX17 auto fp_encoding_selector() { + if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E11M52) { // double + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E8M23) { // float + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E5M2) { // FP8 + // TODO: Not tested. Will be done in another MR + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E4M3) { // FP8 + // TODO: Not tested. Will be done in another MR + return cutlass::detail::FpBitRepresentation{}; + } + + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE4M3) { // FP8 + // TODO: Not tested. Will be done in another MR + return cutlass::detail::FpBitRepresentation{}; + } + + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE8M0) { // FP8 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E3M2) { // FP6 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E2M3) { // FP6 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E2M1) { // FP4 + return cutlass::detail::FpBitRepresentation{}; + } + else { + CUTLASS_GCC_UNREACHABLE; + } +} + +#else +// +// Definitions for floating point encodings. 
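// (Editor's note: each specialization below mirrors one branch of the constexpr-if selector above;
//  for instance, FpEncodingSelector<FpEncoding::E8M23>::type describes the IEEE binary32 layout
//  with 8 exponent bits and 23 mantissa bits.)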
+// + +template struct FpEncodingSelector { + using type = void; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; +#endif + +} // namespace detail + +template +struct float_exmy_base +{ + + static constexpr detail::FpEncoding Encoding = T; + using BitRepresentation = + #if (CUTLASS_CXX17_OR_LATER) + decltype(detail::fp_encoding_selector()) + #else + typename detail::FpEncodingSelector::type + #endif + ; + + using FP32BitRepresentation = + #if (CUTLASS_CXX17_OR_LATER) + decltype(cutlass::detail::fp_encoding_selector()) + #else + typename detail::FpEncodingSelector::type + #endif + ; + + using Storage = typename BitRepresentation::Storage; + + // + // Data members + // + + /// Data container + Storage storage; + + /// Ctors. + float_exmy_base() = default; + + CUTLASS_HOST_DEVICE + float_exmy_base(Storage s) : storage(s) { + } + + /// Is finite implementation + CUTLASS_HOST_DEVICE + static bool isfinite(float_exmy_base flt) { + return !BitRepresentation::is_inf(flt.storage); + } + + /// Is NaN implementation + CUTLASS_HOST_DEVICE + static bool isnan(float_exmy_base flt) { + return BitRepresentation::is_nan(flt.storage); + } + + /// Is infinite implementation + CUTLASS_HOST_DEVICE + static bool isinf(float_exmy_base flt) { + return BitRepresentation::is_inf(flt.storage); + } + + /// Is infinite implementation + CUTLASS_HOST_DEVICE + static bool isnormal(float_exmy_base flt) { + return !BitRepresentation::is_denorm(flt.storage); + } + + CUTLASS_HOST_DEVICE + static float_exmy_base bitcast(Storage x) { + float_exmy_base f; + f.storage = x; + return f; + } + + // TODO: Add rounding parameter with a reasonable default + CUTLASS_HOST_DEVICE + float_exmy_base convert_from_float(float const &flt) const { + // TODO: If we have a cvt instruction specialize in the children structs + FP32BitRepresentation::Storage fp32_bits = FP32BitRepresentation::to_bits(flt); + float_exmy_base float_exmy; + float_exmy.storage = BitRepresentation::convert_from(fp32_bits, FP32BitRepresentation{}); + return float_exmy; + } + + // TODO: Add rounding parameter with a reasonable default + CUTLASS_HOST_DEVICE + float convert_to_float(float_exmy_base const &x) const { + // TODO: If we have a cvt instruction specialize in the children structs + FP32BitRepresentation::Storage fp32_bits; + fp32_bits = BitRepresentation::convert_to(x.storage, FP32BitRepresentation{}); + return detail::copy_bits(fp32_bits); + } + + // Note: Only consider float/int conversions in this Base class + // Types inheriting from this class should define their own constructors and + // specialized type conversions + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_exmy_base(float x) { + storage = 
static_cast(this)->convert_from_float(x).storage; + } + + // Integer conversion + CUTLASS_HOST_DEVICE + explicit float_exmy_base(int x) { + storage = static_cast(this)->convert_from_float(float(x)).storage; + } + + CUTLASS_HOST_DEVICE + explicit float_exmy_base(unsigned x) { + storage = static_cast(this)->convert_from_float(float(x)).storage; + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + return static_cast(this)->convert_to_float(*this); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(static_cast(this)->convert_to_float(*this)); + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + Storage &raw() { + return storage; + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + Storage raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return bool(BitRepresentation::sign_bit(storage)); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int(BitRepresentation::exponent_bits(storage)); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return int(BitRepresentation::exponent(storage)); + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(BitRepresentation::mantissa_bits(storage)); + } + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // + // Arithmetic operators + // + /////////////////////////////////////////////////////////////////////////////////////////////////// + + // Note: Almost all data types cast to float then do the arithmetic operations + // Types inheriting from this class can overload them if specialized instructions are available + // in HW (e.g. 
half_t) + + + CUTLASS_HOST_DEVICE + friend bool operator==(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) == float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator!=(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) != float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator<(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) < float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator<=(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) <= float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator>(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) > float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator>=(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) >= float(rhs); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator+(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) + float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator-(float_exmy_base const &lhs) { + return float_exmy_base(-float(lhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator-(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) - float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator*(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) * float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator/(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) / float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator+=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) + float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator-=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) - float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator*=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) * float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator/=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) / float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator++(float_exmy_base &lhs) { + float tmp(lhs); + ++tmp; + lhs = float_exmy_base(tmp); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator--(float_exmy_base &lhs) { + float tmp(lhs); + --tmp; + lhs = float_exmy_base(tmp); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator++(float_exmy_base &lhs, int) { + float_exmy_base ret(lhs); + float tmp(lhs); + tmp++; + lhs = float_exmy_base(tmp); + return ret; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator--(float_exmy_base &lhs, int) { + float_exmy_base ret(lhs); + float tmp(lhs); + tmp--; + lhs = float_exmy_base(tmp); + return ret; + } + +}; + +template +CUTLASS_HOST_DEVICE +cutlass::float_exmy_base abs(cutlass::float_exmy_base const& h) { + using BitRepresentation = typename cutlass::float_exmy_base::BitRepresentation; + using Storage = typename cutlass::float_exmy_base::Storage; + return BitRepresentation::IS_SIGNED ? 
+ cutlass::float_exmy_base(Storage(h.raw() & Storage((1<(h.raw()); +} +} // namespace cutlass diff --git a/include/cutlass/experimental/distributed/device/full_barrier.hpp b/include/cutlass/experimental/distributed/device/full_barrier.hpp index 8ac9940eef..54b8348d6c 100644 --- a/include/cutlass/experimental/distributed/device/full_barrier.hpp +++ b/include/cutlass/experimental/distributed/device/full_barrier.hpp @@ -47,7 +47,7 @@ void launch_full_barrier( cudaStream_t stream, bool launch_with_pdl) { -#if ((__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) // Legacy (kernel) launch with PDL cudaLaunchAttribute attributes[1]; attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 2f462286d2..34acd40a20 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -36,6 +36,10 @@ #pragma once + +#include "cutlass/arch/config.h" + + // FP8 types are available starting CUDA 11.8+ #if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) #define CUDA_FP8_ENABLED 1 @@ -53,6 +57,12 @@ # endif // (__CUDA_ARCH__ >= 900) #endif // defined(__CUDA_ARCH__) + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) +# define CUDA_PTX_UE8M0_CVT_ENABLED 1 +#endif + + #ifdef __GNUC__ // Ignore checks on reinterpret-casts that are being used for bitcasts. #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -80,6 +90,12 @@ #include #include "cutlass/cutlass.h" + +#include "cutlass/exmy_base.h" + +#include "cute/util/type_traits.hpp" + + /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -1028,6 +1044,204 @@ float_e5m2_t operator--(float_e5m2_t & lhs, int) { return ret; } + +/////////////////////////////////////////////////////////////// +/// +/// floating-point 8 type : UE4M3 +/// +/////////////////////////////////////////////////////////////// +// UE4M3: +// 4 Exponent bits, 3 Mantissa bits +// Range: [0:448] +// has_inf: false +// has_NaN: true +// has_denorm: true +// Exponent bias (exp_bias): 7 +struct float_ue4m3_t : public float_exmy_base { + using Base = float_exmy_base; + + float_ue4m3_t() = default; + + CUTLASS_HOST_DEVICE + float_ue4m3_t convert_from_float(float const &flt) const { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp; + float y = float(); + asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); + return bitcast(*reinterpret_cast(&tmp)); + #else + Base::FP32BitRepresentation::Storage fp32_bits = Base::FP32BitRepresentation::to_bits(flt); + return bitcast(BitRepresentation::convert_from(fp32_bits, Base::FP32BitRepresentation{})); + #endif + } + + CUTLASS_HOST_DEVICE + float convert_to_float(float_ue4m3_t const &x) const { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + return __half2float(reinterpret_cast(packed).x); + #else + Base::FP32BitRepresentation::Storage fp32_bits; + fp32_bits = Base::BitRepresentation::convert_to(x.storage, Base::FP32BitRepresentation{}); + return detail::copy_bits(fp32_bits); + #endif + } + + CUTLASS_HOST_DEVICE + explicit float_ue4m3_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_ue4m3_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_ue4m3_t(int x) : Base(x) 
{ + } + + CUTLASS_HOST_DEVICE + explicit float_ue4m3_t(unsigned x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_ue4m3_t(Base x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_ue4m3_t const& x) { + return x.storage == uint8_t(0x7f); + } + +}; + +/// Defines the size of an element in bits - specialized for float_ue4m3_t +template <> +struct sizeof_bits { + static constexpr int value = sizeof_bits>::value; +}; + + + +/////////////////////////////////////////////////////////////// +/// +/// floating-point 8 type : UE8M0 +/// +/////////////////////////////////////////////////////////////// +// UE8M0: +// 8 Exponent bits, 0 Mantissa bits +// Range: [2^-127:2^127] +// has_inf: false +// has_NaN: true (11111111) +// has_denorm: true +// Exponent bias (exp_bias): 8 + +struct float_ue8m0_t : public float_exmy_base { + using Base = float_exmy_base; + using FP32Bits = typename Base::FP32BitRepresentation; + + float_ue8m0_t() = default; + + CUTLASS_HOST_DEVICE + float_ue8m0_t convert_from_float(float const &flt) const { + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint16_t out; + asm volatile( + "{ cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1; }" + : "=h"(out) : "f"(flt)); + return bitcast(*reinterpret_cast(&out)); + #else + if (CUTLASS_CMATH_NAMESPACE::isnan(flt) || CUTLASS_CMATH_NAMESPACE::isinf(flt)) { + return bitcast(0xFF); + } + uint32_t flt_uint32 = cutlass::detail::copy_bits(flt); + uint8_t exp = (flt_uint32 >> 23) & 0xff; // Extract the 8 bit exponent + uint32_t mant = flt_uint32 & 0x7fffff; // Extract the 23 bit mantissa + // Do the round up + // Deals w/ satfinite all at once + if ((mant > 0) && (exp != 0xFE) && !(exp == 0 && mant <= 0x00400000)) { + exp++; + } + return bitcast(exp); + #endif + } + + CUTLASS_HOST_DEVICE + float convert_to_float(float_ue8m0_t const &x) const { + ////////////////////////////////////////////////////////////// + // The conversion of UE8M0 to FP32 scale can be done simply + // with a left shift (No rounding necessary) + // Note: The base class implements ue8m0 to FP32 based on the rules of float math conversions. + // The result of current implementation and base class are aligned. 
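      // Editor's illustration of the decode: storage 0x7F -> 2^0 = 1.0f, 0x80 -> 2^1 = 2.0f,
      // 0x00 -> 2^-127 (emitted as an FP32 denormal), and 0xFF -> NaN, matching the special
      // cases handled explicitly in the fallback path below.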
+ ////////////////////////////////////////////////////////////// + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t bf16x2_val; + // E8 -> BF16 + asm volatile( + "{\n" + "cvt.rn.bf16x2.ue8m0x2 %0, %1;\n" + "}\n" : "=r"(bf16x2_val): "h"(bits)); + // BF16 -> FP32 + float f1; + asm( + "{\n" + "prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=f"(f1) + : "r"(0), "r"(bf16x2_val), "r"(0x5410)); + return f1; + #else + using FP32Bits = cutlass::detail::FpBitRepresentation; + if (x.storage == 0x00) { + return cutlass::detail::copy_bits(0x00400000); + } + else if (x.storage == 0xFF) { + return cutlass::detail::copy_bits(0x7fffffff); + } + else { + auto f8 = static_cast(x.storage); + FP32Bits::Storage f = (f8 << FP32Bits::NUM_MANTISSA_BITS); + return cutlass::detail::copy_bits(f); + } + #endif + } + + CUTLASS_HOST_DEVICE + explicit float_ue8m0_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_ue8m0_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_ue8m0_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_ue8m0_t(unsigned x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_ue8m0_t(Base x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_ue8m0_t const& x) { + return x.storage == uint8_t(0xff); + } + +}; + +/// Defines the size of an element in bits - specialized for float_ue8m0_t +template <> +struct sizeof_bits { + static constexpr int value = sizeof_bits>::value; +}; + + /////////////////////////////////////////////////////////////////////////////////////////////////// // // float_e4m3_t <=> float_e5m2_t conversions @@ -1074,6 +1288,26 @@ union type_erased_dynamic_float8_t { }; + + +/////////////////////////////////////////////////////////////// +/// MX type for float8 +/// Intended to be used in builders +/////////////////////////////////////////////////////////////// + +template +struct mx_float8_t { + static_assert(cute::is_same_v + || cute::is_same_v + || cute::is_same_v + , "Only float_e5m2_t, float_e4m3_t can have scale factors for MXFP8"); + using ScaleFactorType = cutlass::float_ue8m0_t; + using DataType = F8Type; +}; + +using type_erased_dynamic_mx_float8_t = mx_float8_t; + + /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass @@ -1162,6 +1396,73 @@ struct numeric_limits : static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } }; + +template +struct float8_exmy_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + CUTLASS_HOST_DEVICE + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns 
maximum rounding error + CUTLASS_HOST_DEVICE + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + CUTLASS_HOST_DEVICE + static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + CUTLASS_HOST_DEVICE + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + CUTLASS_HOST_DEVICE + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE + static type denorm_min() { return type::bitcast(0x01); } +}; + +/// Numeric limits for float_ue8m0_t +template <> +struct numeric_limits : + public float8_exmy_numeric_limits { + static bool const has_infinity = false; + static bool const is_signed = false; + + /// Minimum finite value + static cutlass::float_ue8m0_t lowest() { return cutlass::float_ue8m0_t::bitcast(0xfe); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value (2^0) + static cutlass::float_ue8m0_t epsilon() { return cutlass::float_ue8m0_t::bitcast(0x7f); } +}; + + } // namespace std #endif @@ -1251,6 +1552,73 @@ struct numeric_limits : static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } }; + +template +struct float8_exmy_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + CUTLASS_HOST_DEVICE + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + CUTLASS_HOST_DEVICE + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + CUTLASS_HOST_DEVICE + static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + CUTLASS_HOST_DEVICE + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + CUTLASS_HOST_DEVICE + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE + static type denorm_min() { return type::bitcast(0x01); } +}; + +/// Numeric limits for float_ue8m0_t +template <> +struct numeric_limits : + public float8_exmy_numeric_limits { + static bool const has_infinity = false; + static bool const is_signed = false; + + /// Minimum finite value + static cutlass::float_ue8m0_t lowest() { return cutlass::float_ue8m0_t::bitcast(0xfe); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value (2^0) + 
static cutlass::float_ue8m0_t epsilon() { return cutlass::float_ue8m0_t::bitcast(0x7f); } +}; + + } // namespace platform } // namespace cutlass @@ -1271,6 +1639,18 @@ cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { return cutlass::float_e4m3_t(int(x)); } + +CUTLASS_HOST_DEVICE +cutlass::float_ue4m3_t operator "" _fue4m3(long double x) { + return cutlass::float_ue4m3_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_ue4m3_t operator "" _fue4m3(unsigned long long int x) { + return cutlass::float_ue4m3_t(int(x)); +} + + CUTLASS_HOST_DEVICE cutlass::float_e5m2_t operator "" _fe5m2(long double x) { return cutlass::float_e5m2_t(float(x)); @@ -1281,4 +1661,18 @@ cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { return cutlass::float_e5m2_t(int(x)); } + +CUTLASS_HOST_DEVICE +cutlass::float_ue8m0_t operator "" _fue8m0(long double x) +{ + return cutlass::float_ue8m0_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_ue8m0_t operator "" _fue8m0(unsigned long long int x) +{ + return cutlass::float_ue8m0_t(int(x)); +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/float_subbyte.h b/include/cutlass/float_subbyte.h new file mode 100644 index 0000000000..15e539a8ad --- /dev/null +++ b/include/cutlass/float_subbyte.h @@ -0,0 +1,788 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! 
+ \file + \brief Defines classes for FP4/FP6 datatypes +*/ +#pragma once + +#include "cutlass/arch/config.h" +#include "cutlass/float8.h" + +// FP4 types are available starting CUDA 12+ +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUDA_FP4_ENABLED 1 +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) +# define CUDA_PTX_FP4FP6_CVT_ENABLED 1 +#endif +#include "cutlass/cutlass.h" +#include "cutlass/exmy_base.h" + +#include "cute/util/type_traits.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +// FP4 and FP6 types +struct float_e2m1_t; +struct float_e3m2_t; +// E2M1: +// 2 Exponent bits with 1 Mantissa bit +// Range: +-[0,0.5,1,1.5,2,3,4,5,6] +// has_Inf: false +// has_NaN: false +// has_denorm: true +// Exponent bias (exp_bias): 1 + +struct float_e2m1_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m1_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m1_t(Base x) : Base(x) { + } +}; + +namespace detail { + +// This new type is used to select correct MMA type and TMA type. +struct float_e2m1_unpacksmem_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m1_unpacksmem_t() = default; + + CUTLASS_HOST_DEVICE + float_e2m1_unpacksmem_t(float_e2m1_unpacksmem_t const& x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_unpacksmem_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_unpacksmem_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_unpacksmem_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m1_unpacksmem_t(Base x) : Base(x) { + } +}; + +} // namespace detail + +/// Defines the size of an element in bits - specialized for float_e2m1_t +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + +CUTLASS_HOST_DEVICE +float_e2m1_t abs(float_e2m1_t const& val) { + using BaseType = typename float_e2m1_t::Base; + return float_e2m1_t(abs(BaseType{val.raw()})); +} + + +// E2M3: +// 2 Exponent bits with 3 Mantissa bit +// Range: [-7.5,+7.5] +// has_Inf: false +// has_NaN: false +// has_denorm: true +// Exponent bias (exp_bias): 1 + +struct float_e2m3_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m3_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m3_t(Base x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(float_e3m2_t x); +}; + +namespace detail { + +struct float_e2m3_unpack8bits_t: public float_exmy_base { + // Used in register. + using Base = float_exmy_base; + + float_e2m3_unpack8bits_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpack8bits_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpack8bits_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpack8bits_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m3_unpack8bits_t(Base x) : Base(x) { + } +}; + +// This new type is used to select correct MMA type and TMA type. 
+struct float_e2m3_unpacksmem_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m3_unpacksmem_t() = default; + + CUTLASS_HOST_DEVICE + float_e2m3_unpacksmem_t(float_e2m3_unpacksmem_t const& x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpacksmem_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpacksmem_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpacksmem_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m3_unpacksmem_t(Base x) : Base(x) { + } +}; + +} // namespace detail + +/// Defines the size of an element in bits - specialized for float_e2m3_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +/// Defines the size of an element in bits - specialized for float_e2m3_unpacksmem_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +CUTLASS_HOST_DEVICE +float_e2m3_t abs(float_e2m3_t const& val) { + using BaseType = typename float_e2m3_t::Base; + return float_e2m3_t(abs(BaseType{val.raw()})); +} + +// E3M2: +// 3 Exponent bits, 2 Mantissa bits +// Range: [-28:+28] +// has_inf: false +// has_NaN: false +// has_denorm: true +// Exponent bias (exp_bias): 3 + +struct float_e3m2_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e3m2_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e3m2_t(Base x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(float_e2m3_t x); +}; + +namespace detail { + +struct float_e3m2_unpack8bits_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e3m2_unpack8bits_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpack8bits_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpack8bits_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpack8bits_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e3m2_unpack8bits_t(Base x) : Base(x) { + } +}; + +// This new type is used to select correct MMA type and TMA type. 
+struct float_e3m2_unpacksmem_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e3m2_unpacksmem_t() = default; + + CUTLASS_HOST_DEVICE + float_e3m2_unpacksmem_t(float_e3m2_unpacksmem_t const& x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpacksmem_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpacksmem_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpacksmem_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e3m2_unpacksmem_t(Base x) : Base(x) { + } +}; + +} // namespace detail + +/// Defines the size of an element in bits - specialized for float_e3m2_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +/// Defines the size of an element in bits - specialized for float_e3m2_unpacksmem_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +CUTLASS_HOST_DEVICE +float_e3m2_t abs(float_e3m2_t const& val) { + using BaseType = typename float_e3m2_t::Base; + return float_e3m2_t(abs(BaseType{val.raw()})); +} + +/// Defines the size of an element in bits - specialized for float_e3m2_unpack8bits_t +template <> +struct sizeof_bits { + static constexpr int value = 8; +}; + +/// Defines the size of an element in bits - specialized for float_e2m3_unpack8bits_t +template <> +struct sizeof_bits { + static constexpr int value = 8; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Get the register type used in kernel +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct get_unpacked_element_type; + +template <> +struct get_unpacked_element_type { + using type = detail::float_e2m3_unpack8bits_t; +}; + +template <> +struct get_unpacked_element_type { + using type = detail::float_e3m2_unpack8bits_t; +}; +} // namespace detail +// /////////////////////////////////////////////////////////////////////////////////////////////////// +// // +// // float_e2m3_t <=> float_e3m2_t conversions +// // +// /////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +float_e2m3_t::float_e2m3_t(float_e3m2_t x) +{ + storage = convert_from_float(float(x)).storage; +} + +CUTLASS_HOST_DEVICE +float_e3m2_t::float_e3m2_t(float_e2m3_t x) +{ + storage = convert_from_float(float(x)).storage; +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////// +/// +/// Umbrella floating-point 6-bit data type : type_erased_dynamic_float6_t +/// This umbrella datatype can be enabled when a user provides a specific +/// datatype in runtime argument list. 
+/// +/// Currently supported runtime datatypes compatible with type_erased_dynamic_float6_t: +/// MXF8F6F4Format::E2M3 +/// MXF8F6F4Format::E3M2 +/// +/////////////////////////////////////////////////////////////// + +union type_erased_dynamic_float6_t { + cutlass::float_e2m3_t e2m3; + cutlass::float_e3m2_t e3m2; + + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e2m3_t() const { + return e2m3; + } + + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e3m2_t() const { + return e3m2; + } +}; + +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +/////////////////////////////////////////////////////////////// +/// +/// Umbrella floating-point 4-bit data type : type_erased_dynamic_float4_t +/// This umbrella datatype can be enabled when a user provides a specific +/// datatype in runtime argument list. +/// +/// Currently supported runtime datatypes compatible with type_erased_dynamic_float4_t: +/// MXF8F6F4Format::E2M1 +/// +/////////////////////////////////////////////////////////////// + +union type_erased_dynamic_float4_t { + cutlass::float_e2m1_t e2m1; + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e2m1_t() const { + return e2m1; + } +}; + +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + + +/////////////////////////////////////////////////////////////// +/// MX/NV types for float6 and float4 +/// Intended to be used in builders +/////////////////////////////////////////////////////////////// + +template +struct mx_float6_t +{ + static_assert(cute::is_same_v + || cute::is_same_v + || cute::is_same_v + , "Only float_e2m3_t, float_e3m2_t can have scale factors for MXFP6"); + using ScaleFactorType = cutlass::float_ue8m0_t; + using DataType = F6Type; +}; + +using type_erased_dynamic_mx_float6_t = mx_float6_t; + +template +struct mx_float4_t +{ + static_assert(cute::is_same_v + || cute::is_same_v + , "Only float_e2m1_t type_erased_dynamic_float4_t can have scale factors for MXFP4"); + using ScaleFactorType = cutlass::float_ue8m0_t; + using DataType = F4Type; +}; + +using type_erased_dynamic_mx_float4_t = mx_float4_t; + +template +struct nv_float4_t +{ + static_assert(cute::is_same_v + || cute::is_same_v + , "Only float_e2m1_t type_erased_dynamic_float4_t can have scale factors for NVFP4"); + using ScaleFactorType = cutlass::float_ue4m3_t; + using DataType = F4Type; +}; + +using type_erased_dynamic_nv_float4_t = nv_float4_t; + + +namespace detail { + +union type_erased_dynamic_float6_unpacksmem_t { + cutlass::detail::float_e2m3_unpacksmem_t e2m3_unpacksmem; + cutlass::detail::float_e3m2_unpacksmem_t e3m2_unpacksmem; + + CUTLASS_HOST_DEVICE + explicit operator cutlass::detail::float_e2m3_unpacksmem_t() const { + return e2m3_unpacksmem; + } + + CUTLASS_HOST_DEVICE + explicit operator cutlass::detail::float_e3m2_unpacksmem_t() const { + return e3m2_unpacksmem; + } +}; + +union type_erased_dynamic_float4_unpacksmem_t { + cutlass::detail::float_e2m1_unpacksmem_t e2m1_unpacksmem; + + CUTLASS_HOST_DEVICE + explicit operator cutlass::detail::float_e2m1_unpacksmem_t() const { + return e2m1_unpacksmem; + } +}; + +}; + +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + + +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// 
+/////////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) +namespace std { +/// Numeric limits common to all float4 types +template +struct float_subbyte_base_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = false; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + static type denorm_min() { return type::bitcast(0x01); } +}; +/// Numeric limits for float_e2m1_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m1_t lowest() { return cutlass::float_e2m1_t::bitcast(0xf); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m1_t epsilon() { return cutlass::float_e2m1_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e2m3_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m3_t lowest() { return cutlass::float_e2m3_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m3_t epsilon() { return cutlass::float_e2m3_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e3m2_t + +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e3m2_t lowest() { return cutlass::float_e3m2_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e3m2_t epsilon() { return cutlass::float_e3m2_t::bitcast(0x4); } +}; +} // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Numeric limits common to all float4 types +template +struct float_subbyte_base_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + 
static bool const has_quiet_NaN = false; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + static type denorm_min() { return type::bitcast(0x01); } +}; + +/// Forward Declaration +template +struct numeric_limits; +/// Numeric limits for float_e2m1_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m1_t lowest() { return cutlass::float_e2m1_t::bitcast(0xf); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m1_t epsilon() { return cutlass::float_e2m1_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e2m3_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m3_t lowest() { return cutlass::float_e2m3_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m3_t epsilon() { return cutlass::float_e2m3_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e3m2_t + +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e3m2_t lowest() { return cutlass::float_e3m2_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e3m2_t epsilon() { return cutlass::float_e3m2_t::bitcast(0x4); } +}; + +/// Numeric limits for float_e2m3_unpack8bits_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::detail::float_e2m3_unpack8bits_t lowest() { return cutlass::detail::float_e2m3_unpack8bits_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::detail::float_e2m3_unpack8bits_t epsilon() { return cutlass::detail::float_e2m3_unpack8bits_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e3m2_unpack8bits_t + +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static 
cutlass::detail::float_e3m2_unpack8bits_t lowest() { return cutlass::detail::float_e3m2_unpack8bits_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::detail::float_e3m2_unpack8bits_t epsilon() { return cutlass::detail::float_e3m2_unpack8bits_t::bitcast(0x4); } +}; +} // namespace platform + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// +CUTLASS_HOST_DEVICE +cutlass::float_e2m1_t operator"" _fe2m1(long double x) +{ + return cutlass::float_e2m1_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e2m1_t operator"" _fe2m1(unsigned long long int x) +{ + return cutlass::float_e2m1_t(int(x)); +} +CUTLASS_HOST_DEVICE +cutlass::float_e2m3_t operator"" _fe2m3(long double x) +{ + return cutlass::float_e2m3_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e2m3_t operator"" _fe2m3(unsigned long long int x) +{ + return cutlass::float_e2m3_t(int(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e3m2_t operator"" _fe3m2(long double x) +{ + return cutlass::float_e3m2_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e3m2_t operator"" _fe3m2(unsigned long long int x) +{ + return cutlass::float_e3m2_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 52a4d1428b..5d3d6fca43 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -53,6 +53,12 @@ #include #endif // _MSC_VER + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +# define CUTLASS_ARCH_CREDUX_ENABLED +#endif + + namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -275,6 +281,16 @@ struct reciprocal_approximate { } }; + +template <> +struct reciprocal_approximate { + CUTLASS_HOST_DEVICE + cutlass::float_ue8m0_t operator()(cutlass::float_ue8m0_t lhs) const { + return cutlass::float_ue8m0_t::bitcast(static_cast(static_cast(254u) - lhs.storage)); + } +}; + + /// reciprocal_approximate with ftz template struct reciprocal_approximate_ftz : reciprocal_approximate @@ -586,6 +602,33 @@ struct guarded_multiply_add_relu0 { } }; + +/// Fused and-popc-add +template +struct and_popc_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + A and_result = a & b; + +#if defined(__CUDA__ARCH__) + int popc_result = __popc(and_result); + + if constexpr (sizeof(A) == sizeof(uint64_t)) { + popc_result += __popc(static_cast(and_result >> 32)); + } + +#else + int popc_result = __builtin_popcount(and_result); + if constexpr (sizeof(A) == sizeof(uint64_t)) { + popc_result += __builtin_popcount(static_cast(and_result >> 32)); + } + +#endif + + return C(popc_result) + c; + } +}; + /// Fused multiply-add template struct and_add { @@ -596,6 +639,33 @@ struct and_add { }; + +/// Fused xor-popc-add +template +struct xor_popc_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + A xor_result = a ^ b; + +#if defined(__CUDA__ARCH__) + int popc_result = __popc(xor_result); + + if constexpr (sizeof(A) == sizeof(uint64_t)) { + popc_result += __popc(static_cast(xor_result >> 32)); + } + +#else + int popc_result = __builtin_popcount(xor_result); + if constexpr (sizeof(A) == sizeof(uint64_t)) { + popc_result += __builtin_popcount(static_cast(xor_result >> 32)); 
+ } + +#endif + + return C(popc_result) + c; + } +}; + /// Fused multiply-add template struct xor_add { @@ -605,6 +675,43 @@ struct xor_add { } }; + +/// Fused or-popc-add +template +struct or_popc_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + A or_result = a | b; + +#if defined(__CUDA__ARCH__) + int popc_result = __popc(or_result); + + if constexpr (sizeof(A) == sizeof(uint64_t)) { + popc_result += __popc(static_cast(or_result >> 32)); + } + +#else + int popc_result = __builtin_popcount(or_result); + if constexpr (sizeof(A) == sizeof(uint64_t)) { + popc_result += __builtin_popcount(static_cast(or_result >> 32)); + } + +#endif + + return C(popc_result) + c; + } +}; + + +/// Fused multiply-add +template +struct or_add { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b, T const &c) const { + return ((a | b) + c); + } +}; + namespace detail { // Whether namespace-unqualified conj(t) for t of type T is @@ -886,6 +993,78 @@ struct is_atomic> : platform::true_type {}; template struct is_atomic> : platform::true_type {}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +/// Parallel Synchronization and Communication Instructions +template +struct redux_abs_max_nan_propagation_sync_warp; + +template <> +struct redux_abs_max_nan_propagation_sync_warp { + CUTLASS_DEVICE + float operator()(float const &lhs) const { +#if defined(CUTLASS_ARCH_CREDUX_ENABLED) + float result; + asm volatile("redux.sync.max.abs.NaN.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(lhs)); + return result; +#elif defined(__CUDA_ARCH__) + cutlass::maximum max_op; + int shuffle_width = 32; + float abs_max = cutlass::absolute_value_op{}(lhs); + CUTLASS_PRAGMA_UNROLL + for(int offset = shuffle_width / 2; offset > 0; offset /= 2) { + float value = __shfl_down_sync(0xffffffff, abs_max, offset, shuffle_width); + abs_max = max_op(abs_max,value); + } + // Broadcast the maximum to all threads participating in the reduction. + abs_max = __shfl_sync(0xffffffff, abs_max, 0, shuffle_width); + return abs_max; +#else + CUTLASS_UNUSED(lhs); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + +template +struct redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31; + +template <> +struct redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31{ + CUTLASS_DEVICE + float operator()(float const &max) const { +#if defined(CUTLASS_ARCH_CREDUX_ENABLED) + int half_warp_idx = threadIdx.x / (NumThreadsPerWarp / 2); + bool first_half_threads = (half_warp_idx % 2) == 0; + float value0 = first_half_threads ? max : 0; + float v0 = cutlass::redux_abs_max_nan_propagation_sync_warp{}(value0); + + float value1 = !first_half_threads ? max : 0; + float v1 = cutlass::redux_abs_max_nan_propagation_sync_warp{}(value1); + return first_half_threads ? v0: v1; + +#elif defined(__CUDA_ARCH__) + float abs_max = cutlass::absolute_value_op{}(max); + cutlass::maximum max_op; + constexpr int shuffle_width = 16; + CUTLASS_PRAGMA_UNROLL + for(int offset = shuffle_width/2; offset > 0; offset /= 2) { + float value = __shfl_down_sync(0xffffffff, abs_max, offset, shuffle_width); + abs_max = max_op(abs_max,value); + } + // Broadcast the maximum to all threads participating in the reduction. 
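      // Editorial note (descriptive comment only, not proposed patch content): with
      // shuffle_width == 16, __shfl_sync treats each 16-lane half-warp as its own
      // segment and the source lane index 0 is taken relative to that segment, so
      // lanes 0-15 read the value reduced into lane 0 and lanes 16-31 read the value
      // reduced into lane 16. Each half-warp therefore ends up broadcasting its own
      // abs-max, which is the behavior the t0t15/t16t31 variant is meant to provide.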
+ abs_max = __shfl_sync(0xffffffff, abs_max, 0, shuffle_width); + return abs_max; +#else + CUTLASS_UNUSED(max); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for nvcuda::wmma::fragment diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl new file mode 100644 index 0000000000..b4927883fa --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -0,0 +1,782 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +// + +// + +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
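// Editorial worked example (assumed numbers, not taken from this patch) for the
// overload below: with 8-bit smem element types and a 128x128x256 CTA tile, one stage
// holds
//   A tile : 128 * 256 * 1 byte = 32 KiB
//   B tile : 128 * 256 * 1 byte = 32 KiB
// plus the much smaller SFA/SFB scale-factor tiles and one mainloop pipeline barrier,
// so stage_bytes is a little over 64 KiB and the stage count is simply
// (CapacityBytes - carveout_bytes) / stage_bytes, rounded down by integer division.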
+template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int carveout_bytes +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout stage_count) { + // For Mxf8f6f4 sub-bytes, ElementA/B will be passed in as uint8_t + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) + // 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. sizeof(sf) is fixed) + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); + constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{})); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +template +constexpr auto +sm100_cluster_shape_to_tma_atom_SFB(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) { + static_assert(cute::rank(cluster_shape_mnk) == 3); + if constexpr (cute::size(atom_thr_id) == 2) { + // Always could use multicast feature for SFB with 2cta MMA. + return cute::SM100_TMA_2SM_LOAD_MULTICAST{}; + } + else if constexpr (size(atom_thr_id) == 1) { + return detail::sm90_cluster_shape_to_tma_atom(cute::size<0>(cluster_shape_mnk)); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported Configuration for SM100 TMA"); + } +} + +namespace blockscaled { + +enum class BlockScaledInstr { + MXF4_NVF4, + MXF4F6F8 +}; + +template +struct blockscaled_type {}; + +template +struct blockscaled_type> { + using sf_type = SF; + using data_type = T; + static constexpr uint32_t SfVectorSize = detail::find_vector_size(); +}; + +template +struct blockscaled_type>> { + using sf_type = SF; + using data_type = T; + static constexpr uint32_t SfVectorSize = SfVectorSize_; +}; + +template +struct blockscaled_type> { + using sf_type = cutlass::float_ue8m0_t; + using data_type = T; + static constexpr uint32_t SfVectorSize = 32; +}; + +template +struct blockscaled_type> { + using sf_type = cutlass::float_ue8m0_t; + using data_type = T; + static constexpr uint32_t SfVectorSize = 32; +}; + +template +struct blockscaled_type> { + using sf_type = cutlass::float_ue4m3_t; + using data_type = T; + static constexpr uint32_t SfVectorSize = 16; +}; +template +struct blockscaled_type> { + using sf_type = cutlass::float_ue8m0_t; + using data_type = T; + static constexpr uint32_t SfVectorSize = 32; +}; + +template < + class KernelScheduleType, + class ElementPairA, class ElementPairB, + UMMA::Major UmmaMajorA, UMMA::Major UmmaMajorB +> +CUTLASS_HOST_DEVICE +static constexpr bool +check_input_datatypes() { + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; 
+ constexpr uint32_t SfVectorSizeA = detail::blockscaled::blockscaled_type::SfVectorSize; + constexpr uint32_t SfVectorSizeB = detail::blockscaled::blockscaled_type::SfVectorSize; + + auto is_auto_instr_selection_policy = [&]() { + return ((cute::is_same_v) || + (cute::is_same_v) || + (cute::is_same_v) || + (cute::is_same_v) || + (cute::is_same_v) || + (cute::is_same_v) || + (cute::is_same_v)); + }; + + static_assert(cute::is_same_v, "Scale factor types for A and B should be the same."); + static_assert((SfVectorSizeA == SfVectorSizeB), "Scale factor vector size for A and B should be the same."); + if constexpr ((SfVectorSizeA == 0) || (SfVectorSizeB == 0)) { + static_assert(!is_auto_instr_selection_policy(), "Auto instr selection isn't valid if scale factor vector size can't be determined from the types"); + } + + static_assert(cute::is_same_v + || cute::is_same_v, "Incorrect scale factor type"); + + if constexpr (((sizeof_bits_v == 4 || sizeof_bits_v == 6 || sizeof_bits_v == 8) && + (sizeof_bits_v == 4 || sizeof_bits_v == 6 || sizeof_bits_v == 8) ) && // A and B are 4, 6, or 8 bit types and + (!(sizeof_bits_v == 4 && sizeof_bits_v == 4) ) // A and B are not both 4 bit types + ) { + /////////////////////////////////////////////////////////////////////// + // Mixed Precision FP4, FP6, FP8 case. -> MX_F4F6F8 instructions + /////////////////////////////////////////////////////////////////////// + // 1. Check Scale factor data type + static_assert(cute::is_same_v, "MX_F4F6F8 only supports ue8m0 SF type"); + // 2. Check whether A and B type combinations are valid or not + static_assert( + ( // If runtime datatypes are used, then both A and B should be runtime data type + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) && + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) + ) || + ( // Valid (explicit) A and B type pairs + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) && + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) + ), "Incorrect types for A and B for MX_F4F6F8" + ); + // 3. Check Scale factor vector size is valid. + // Only SfVectorSize = 32 is allowed. + static_assert((SfVectorSizeA == 32) && (SfVectorSizeB == 32), "Incorrect SfVectorSize for MX_F4F6F8 is deduced. SfVectorSize should be 32."); + // 4. Check the kernel policy. Kernel policy should be either auto or *MXf8f6f4* + static_assert((cute::is_base_of_v || + cute::is_base_of_v || + is_auto_instr_selection_policy()), "Incorrect Kernel Schedule Policy for Mx_F4F6F8 type inputs."); + + return true; + } + else if constexpr ((sizeof_bits_v == 4 && sizeof_bits_v == 4)) { + /////////////////////////////////////////////////////////////////////// + // A and B are both 4 bit types + // There are multiple block scaled tcgen05.mma instructions supporting F4 types. + /////////////////////////////////////////////////////////////////////// + + // 1. Check Scale factor data type + static_assert(cute::is_same_v + || cute::is_same_v + , "MXNV_F4 supports ue8m0 and ue4m3 SF types"); + // 2. Check whether A and B type combinations are valid or not + static_assert( + ( // If runtime datatypes are used, then both A and B should be runtime data type + cute::is_same_v && + cute::is_same_v + ) || + ( // Valid (explicit) A and B type pairs + ( + cute::is_same_v + ) && + ( + cute::is_same_v + ) + ), "Incorrect types for A and B for MXNV_F4"); + // 3. Skip checking the scale factor vector size. 
Will be checked later for specific Kernel Schedule policies. + // 4. Check the kernel policy. + static_assert((cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + is_auto_instr_selection_policy()), "Incorrect Kernel Schedule Policy for F4 type inputs."); + // If a policy is specified, do more checks + if constexpr (cute::is_base_of_v + || cute::is_base_of_v + ) { + // Perform additional checks. Only subset of FP4 and scale factor types are supported. + static_assert(cute::is_same_v, "MX_F4F6F8 only supports ue8m0 SF type"); + static_assert((cute::is_same_v && cute::is_same_v) || + (cute::is_same_v && cute::is_same_v), "Incorrect types for A and B for MX_F4F6F8"); + static_assert((SfVectorSizeA == 32) && (SfVectorSizeB == 32), "Incorrect SfVectorSize for MX_F4F6F8 is deduced. SfVectorSize should be 32."); + return true; + } + else if constexpr (cute::is_base_of_v + || cute::is_base_of_v + ) { + static_assert((UmmaMajorA == UMMA::Major::K && UmmaMajorB == UMMA::Major::K), "MX/NV_F4 only supports RowMajor A, and ColMajorB"); + static_assert(detail::find_vector_size() == SfVectorSizeA, "Kernel Schedule policy doesn't match the scale factor vector size."); + return true; + } + else { // auto policy + // If the scale factor type is ue4m3 or the scale factor vector size is 16 -> only MXF4_NVF4 instruction can support it + // For MXF4_NVF4, the layouts should be RowMajor A, and ColMajorB + static_assert(is_auto_instr_selection_policy(), "Kernel Schedule policy should be auto"); + if constexpr (SfVectorSizeA == 16 || SfVectorSizeB == 16 + || cute::is_same_v + ) { // Only MXF4NVF4 can support these types + static_assert((UmmaMajorA == UMMA::Major::K && UmmaMajorB == UMMA::Major::K), "NV_F4 only supports RowMajor A, and ColMajorB"); + return true; + } + return true; + } + } + else { + return false; + } + return false; +} + +template < + class TileShape_MNK, // (MmaAtomShape_M, MmaAtomShape_N, CtaTileShapeK) + class ClusterShape_MNK, + class KernelScheduleType +> +CUTLASS_HOST_DEVICE +static constexpr bool +is_2sm() { + // 2SM kernel schedule is requested + if constexpr (cute::is_base_of_v) { return true; } + // 1SM kernel schedule is requested + else if constexpr (cute::is_base_of_v) { return false; } + // auto schedule is used. + else { + if constexpr (!cute::is_static_v) { + // If the cluster shape is dynamic, we can't guarantee 2x1. Default to 1sm. + // If tile shape M is 256, throw an error. M=256 is only supported by 2SM instructions. + static_assert(get<0>(TileShape_MNK{}) != 256, "If M=256, auto policy can't create 2sm kernels. 
Specify a 2SM policy"); + return false; + } + else if constexpr (cute::is_static_v && cute::get<0>(ClusterShape_MNK{}) % 2 == 0) { + // We need to check the TileShape + if constexpr (get<0>(TileShape_MNK{}) == 256) { + return true; + } + else if constexpr (get<0>(TileShape_MNK{}) == 128) { + return false; + } + else { + static_assert(get<0>(TileShape_MNK{}) == 0, "Unsupported M dimension for TileShape_MNK."); + } + } + else { return false;} + } +} + +template < + class ElementPairA, + class ElementPairB, + class ElementAccumulator, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class KernelScheduleType +> +CUTLASS_HOST_DEVICE +static constexpr auto +select_instr() { + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + constexpr uint32_t SfVectorSizeA = detail::blockscaled::blockscaled_type::SfVectorSize; + constexpr uint32_t SfVectorSizeB = detail::blockscaled::blockscaled_type::SfVectorSize; + constexpr int SFVectorSize = SfVectorSizeA > SfVectorSizeB ? SfVectorSizeA : SfVectorSizeB; + using ElementSF = ElementSFA; + + if constexpr (cute::is_base_of_v + || cute::is_base_of_v + ) { + return detail::blockscaled::BlockScaledInstr::MXF4F6F8; + } + else if constexpr (cute::is_base_of_v + || cute::is_base_of_v + ) { + return detail::blockscaled::BlockScaledInstr::MXF4_NVF4; + } + else { + // Auto scheduling + if constexpr ((sizeof_bits_v >= 6 && sizeof_bits_v <= 8) && + (sizeof_bits_v >= 6 && sizeof_bits_v <= 8)) { + // These types can only be supported by MX_F8F6F4 instruction + static_assert(SFVectorSize == 32, "Incorrect SF vector size"); + return detail::blockscaled::BlockScaledInstr::MXF4F6F8; + } + else if constexpr (( sizeof_bits_v == 4 && (sizeof_bits_v == 6 || sizeof_bits_v == 8)) || + ((sizeof_bits_v == 6 || sizeof_bits_v == 8) && sizeof_bits_v == 4)) { + // Fp4 can be mixed with FP6, Fp8 with Mxf8f6f4 only + return detail::blockscaled::BlockScaledInstr::MXF4F6F8; + } + else if constexpr (sizeof_bits_v == 4 && sizeof_bits_v == 4) { + // Both A and B are 4bits + if constexpr (UmmaMajorA == UMMA::Major::K && UmmaMajorB == UMMA::Major::K) { + // MXF4_NVF4 possible + return detail::blockscaled::BlockScaledInstr::MXF4_NVF4; + } + else { + static_assert(SFVectorSize == 32, "Incorrect SF vector size"); + static_assert( cute::is_same_v && + (cute::is_same_v && cute::is_same_v || + cute::is_same_v && cute::is_same_v), + "Only MXF4 support with non-TN and Mxf8f6f4"); + return detail::blockscaled::BlockScaledInstr::MXF4F6F8; + } + } + } +} + +} // namespace blockscaled + +template < + class ElementPairA, + class ElementPairB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + detail::blockscaled::BlockScaledInstr Instr, + class KernelScheduleType +> +constexpr auto +sm100_make_blockscaled_1sm_trivial_tiled_mma() { + // For MMA_1sm atoms, the MMA's AtomLayout is same as the ClusterShape + using AtomLayout_MNK = Layout; + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. 
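  // Editorial note (descriptive comment only): "trivial" here means the tiled MMA is a
  // single block-scaled UMMA atom -- the CTA tile N is carried entirely by the atom's N
  // extent rather than by tiling several atoms along N -- so N is restricted below to
  // the tile extents this builder supports for the block-scaled path (128, 192 or 256).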
+ constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N == 128 || N == 192 || N == 256, "Invalid TileShape_N."); + + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + constexpr uint32_t SfVectorSizeA = detail::blockscaled::blockscaled_type::SfVectorSize; + [[maybe_unused]] constexpr uint32_t SfVectorSizeB = detail::blockscaled::blockscaled_type::SfVectorSize; + + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + + using ElementSF = ElementSFA; + + if constexpr (Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8) { + + return make_tiled_mma(cute::SM100_MMA_MXF8F6F4_SS{}); + } + else if constexpr (Instr == detail::blockscaled::BlockScaledInstr::MXF4_NVF4) { + constexpr int SFVectorSize = SfVectorSizeA; + return make_tiled_mma(cute::SM100_MMA_MXF4_SS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM100 collective builder."); + } +} + +template < + class ElementPairA, + class ElementPairB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + detail::blockscaled::BlockScaledInstr Instr, + class KernelScheduleType +> +constexpr auto +sm100_make_blockscaled_2sm_trivial_tiled_mma() { + + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 256, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N == 128 || N == 192 || N == 256, "Invalid TileShape_N."); + + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + constexpr uint32_t SfVectorSizeA = detail::blockscaled::blockscaled_type::SfVectorSize; + [[maybe_unused]] constexpr uint32_t SfVectorSizeB = detail::blockscaled::blockscaled_type::SfVectorSize; + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + + using ElementSF = ElementSFA; + + if constexpr (Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8) { + return make_tiled_mma(cute::SM100_MMA_MXF8F6F4_2x1SM_SS{}); + } + else if constexpr (Instr == detail::blockscaled::BlockScaledInstr::MXF4_NVF4) { + constexpr int SFVectorSize = SfVectorSizeA > SfVectorSizeB ? 
SfVectorSizeA : SfVectorSizeB; + return make_tiled_mma(cute::SM100_MMA_MXF4_2x1SM_SS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM100 collective builder."); + } +} + + +template < + class ElementPairA, + class ElementPairB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + detail::blockscaled::BlockScaledInstr Instr, + class KernelScheduleType, + bool Is2SM +> +struct TrivialBlockscaledMma {}; + +template < + class ElementPairA, + class ElementPairB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + detail::blockscaled::BlockScaledInstr Instr, + class KernelScheduleType +> +struct TrivialBlockscaledMma < + ElementPairA, + ElementPairB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + UmmaMajorA, + UmmaMajorB, + Instr, + KernelScheduleType, + true /*Is2SM*/> { + using type = decltype(sm100_make_blockscaled_2sm_trivial_tiled_mma()); + }; + +template < + class ElementPairA, + class ElementPairB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + detail::blockscaled::BlockScaledInstr Instr, + class KernelScheduleType +> +struct TrivialBlockscaledMma< + ElementPairA, + ElementPairB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + UmmaMajorA, + UmmaMajorB, + Instr, + KernelScheduleType, + false /*Is2SM*/> { + using type = decltype(sm100_make_blockscaled_1sm_trivial_tiled_mma()); +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + // Blockscaled Gemm + (cute::is_base_of_v || + cute::is_same_v) + && + // Alignment check + detail::sm1xx_blockscaled_gemm_is_aligned::data_type, + AlignmentA, + typename detail::blockscaled::blockscaled_type::data_type, + AlignmentB, + KernelScheduleType>()>> +{ + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + using ElementSF = ElementSFA; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); + + static constexpr bool is_2sm = detail::blockscaled::is_2sm(); + static constexpr auto Instr = detail::blockscaled::select_instr(); + + using TiledMma = typename 
cutlass::gemm::collective::detail::TrivialBlockscaledMma::type; + + static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; + + static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B()), "Only Mxf8f6f4 supports non-K major inputs"); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + + static_assert(detail::sm100_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + + static constexpr uint32_t SFVectorSize = TiledMma::SFVecSize; + + // Basic storage block for new Scaling Factor Layouts + using AtomThrID = typename TiledMma::AtomThrID; + using Sm100BlkScaledConfig = cutlass::detail::Sm100BlockScaledConfig; + + using ElementAMma_SmemAllocType = cute::conditional_t; + using ElementBMma_SmemAllocType = cute::conditional_t; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{})); + + // + // Construct SMEM layout (SmemLayoutAtom) for A and SFA + // + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, BlockTileA_M, BlockTileA_K>()); + + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 
32bits corresponds to the TMEM word size + using Blk_MN = typename Sm100BlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm100BlkScaledConfig::Blk_SF; + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + using SmemLayoutAtomSFA = decltype(Sm100BlkScaledConfig::deduce_smem_layoutSFA(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + + // + // Construct SMEM layout (SmemLayoutAtom) for B and SFB + // + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + using SmemLayoutAtomSFB = decltype(Sm100BlkScaledConfig::deduce_smem_layoutSFB(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); + + // + // Construct Strides for A, SFA, B, and SFB + // + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm100BlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB = decltype(Sm100BlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA *>; + using LayoutSFB = cute::conditional_t, InternalLayoutSFB, InternalLayoutSFB *>; + using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{})); + + static constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); + static constexpr uint32_t AccumulatorPipelineStageCount = (MMA_N == 256) ? 1 : 2; + // Grouped GEMM (where Stride type is Stride*) does not use CLC based scheduler. + static constexpr uint32_t SchedulerPipelineStageCount = cute::is_same_v ? 3 : 1; + static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v; + static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< + ClusterShape_MNK, + AccumulatorPipelineStageCount, + SchedulerPipelineStageCount, + detail::CLCResponseSize, + IsArrayOfPointersGemm, + 4 // 4 Tensor maps for A, SFA, B and SFB + >::KernelSmemCarveout; + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + static_assert(PipelineStages > 0, "Smem usage is too high. 
Can't create any SMEM buffers for A, B, SFA, and SFB."); + + using DispatchPolicy = + cute::conditional_t, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + cute::tuple, + StridePairA, + cute::tuple, + StridePairB, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + void, + cute::identity, + GmemTiledCopyPairB, + SmemLayoutAtomsB, + void, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl new file mode 100644 index 0000000000..464ffe89a0 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -0,0 +1,572 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
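// Editorial usage sketch (not proposed patch content): the specialization above is
// reached through the usual CollectiveBuilder entry point by pairing each operand with
// its scale-factor carrier type (nv_float4_t / mx_float4_t / mx_float6_t / mx_float8_t).
// The shapes, alignments and stage-count tag below are illustrative assumptions, not
// values taken from this patch; with two K-major e2m1 operands and ue4m3 scale factors,
// select_instr() above resolves to BlockScaledInstr::MXF4_NVF4.
//
//   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
//       cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
//       cutlass::nv_float4_t<cutlass::float_e2m1_t>, cutlass::layout::RowMajor,    32,  // A, K-major
//       cutlass::nv_float4_t<cutlass::float_e2m1_t>, cutlass::layout::ColumnMajor, 32,  // B, K-major
//       float,                                            // accumulator
//       cute::Shape<cute::_128, cute::_128, cute::_256>,  // MMA tile (M, N, K)
//       cute::Shape<cute::_1, cute::_1, cute::_1>,        // cluster shape
//       cutlass::gemm::collective::StageCountAutoCarveout<0>,  // 0 = no epilogue carveout here
//       cutlass::gemm::collective::KernelScheduleAuto
//     >::CollectiveOp;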
+ * + **************************************************************************************************/ +// + +// +#pragma once + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" // KernelSchedule1Sm, KernelSchedule2Sm +#include "cutlass/gemm/collective/builders/sm90_common.inl" // detail::sm90_cluster_shape_to_tma_atom() +#include "cutlass/numeric_types.h" // all numeric types +#include "cutlass/detail/dependent_false.hpp" // detail::dependent_false +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/layout.hpp" // cutlass::detail::get_input_alignment_bits() +#include "cutlass/layout/matrix.h" // cutlass::layout::RowMajor, cutlass::layout::ColumnMajor +#include "cutlass/fast_math.h" // cutlass::round_up, cutlass::const_max +#include "cutlass/arch/arch.h" + +#include "cute/atom/mma_traits_sm100.hpp" // UMMA::Layout_MN_SW* +#include "cute/atom/copy_traits_sm100_tma.hpp" // SM100_TMA_*SM_LOAD_* +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::Major +#include "cute/arch/mma_sm100_umma.hpp" // SM100_*MMA_SS_* +#include "cute/numeric/integral_constant.hpp" // is_static_v, cute::integral_constant +#include "cute/util/type_traits.hpp" // cute::alignment_of_v + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +// Forward Declaration +struct KernelScheduleAuto; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// Some named constants +// +constexpr int sm100_smem_capacity_bytes = cutlass::arch::sm100_smem_capacity_bytes; +constexpr int CLCResponseSize = + sizeof(typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100,1>::CLCResponse{}); + +// Maps input element to umma element +template +constexpr auto +sm100_kernel_input_element_to_mma_input_element() { + if constexpr (cute::is_same_v) { + return cutlass::tfloat32_t{}; + } + else if constexpr (cute::is_same_v && IsF8F6F4) { + return cutlass::detail::float_e2m1_unpacksmem_t{}; + } + else if constexpr (cute::is_same_v && IsF8F6F4) { + return cutlass::detail::float_e3m2_unpacksmem_t{}; + } + else if constexpr (cute::is_same_v && IsF8F6F4) { + return cutlass::detail::float_e2m3_unpacksmem_t{}; + } + else if constexpr (cute::is_same_v && IsF8F6F4) { + return cutlass::detail::type_erased_dynamic_float4_unpacksmem_t{}; + } + else if constexpr (cute::is_same_v && IsF8F6F4) { + return cutlass::detail::type_erased_dynamic_float6_unpacksmem_t{}; + } + else { + return Element{}; + } +} + +// Maps 2.x A matrix layout tag to respective UMMA major mode enum +template +constexpr cute::UMMA::Major +tag_to_umma_major_A() { + using LayoutA = cute::remove_pointer_t; + if constexpr (cute::is_same_v) { + return cute::UMMA::Major::K; + } + else if constexpr (cute::is_same_v) { + return cute::UMMA::Major::MN; + } + else if constexpr (cutlass::detail::is_major<0, LayoutA>()) { + return cute::UMMA::Major::MN; + } + else if constexpr (cutlass::detail::is_major<1, LayoutA>()) { + return cute::UMMA::Major::K; + } + else { + static_assert(sizeof(LayoutA) == 0, "Invalid layout."); + } +} + +// Maps 2.x B matrix layout tag to respective UMMA major mode enum +template +constexpr cute::UMMA::Major +tag_to_umma_major_B() { + using LayoutB = cute::remove_pointer_t; + if constexpr (cute::is_same_v) { + return cute::UMMA::Major::MN; + } + else if 
constexpr (cute::is_same_v) { + return cute::UMMA::Major::K; + } + else if constexpr (cutlass::detail::is_major<0, LayoutB>()) { + return cute::UMMA::Major::MN; + } + else if constexpr (cutlass::detail::is_major<1, LayoutB>()) { + return cute::UMMA::Major::K; + } + else { + static_assert(sizeof(LayoutB) == 0, "Invalid layout."); + } +} + +// Helper for SS UMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest UMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +sm100_smem_selector() { + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + if constexpr (major == cute::UMMA::Major::MN) { + // Handle the special case for F32 NT kernels + if constexpr ((sizeof(ElementType) == 4)) { + static_assert(BLK_MN0 % size<0>(UMMA::Layout_MN_SW128_32B_Atom{}) == 0, "for mn-major tf32 operands, SW128_32B is the only available smem layout"); + return UMMA::Layout_MN_SW128_32B_Atom{}; + } + else { + // All other data types are handled as SM90 + if constexpr (BLK_MN0 % size<0>(UMMA::Layout_MN_SW128_Atom{}) == 0) { + return UMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(UMMA::Layout_MN_SW64_Atom{}) == 0) { + return UMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(UMMA::Layout_MN_SW32_Atom{}) == 0) { + return UMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(UMMA::Layout_MN_INTER_Atom{}) == 0) { + return UMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(UMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(UMMA::Layout_MN_INTER_Atom{})"); + } + } + } + else if constexpr (major == cute::UMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(UMMA::Layout_K_SW128_Atom{}) == 0) { + return UMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(UMMA::Layout_K_SW64_Atom{}) == 0) { + return UMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(UMMA::Layout_K_SW32_Atom{}) == 0) { + return UMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(UMMA::Layout_K_INTER_Atom{}) == 0) { + return UMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(UMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(UMMA::Layout_K_INTER_Atom{})"); + } + } +} + +template +constexpr auto +sm100_cluster_shape_to_tma_atom_A(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) { + static_assert(cute::rank(cluster_shape_mnk) == 3); + constexpr bool IsDynamicCluster = not cute::is_static_v; + + if constexpr (cute::size(atom_thr_id) == 2) { + if constexpr (!IsDynamicCluster) { + static_assert(cute::size<0>(cluster_shape_mnk) % 2 == 0, "Cluster shape not divisible by MMA size"); + if constexpr (cute::size<1>(cluster_shape_mnk) == 1) { + return cute::SM100_TMA_2SM_LOAD{}; + } + else { + return cute::SM100_TMA_2SM_LOAD_MULTICAST{}; + } + } + else { + return cute::SM100_TMA_2SM_LOAD_MULTICAST{}; + } + } + else if constexpr (size(atom_thr_id) == 1) { + if constexpr (!IsDynamicCluster) { + return detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(cluster_shape_mnk)); + } + else { + // In the case of dynamic cluster, multicast decision is not known at compile time. 
+ // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. + return detail::sm90_cluster_shape_to_tma_atom(cute::Int<2>{}); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported Configuration for SM100 TMA"); + } +} + +template +constexpr auto +sm100_cluster_shape_to_tma_atom_B(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) { + static_assert(cute::rank(cluster_shape_mnk) == 3); + constexpr bool IsDynamicCluster = not cute::is_static_v; + + if constexpr (cute::size(atom_thr_id) == 2) { + if constexpr (!IsDynamicCluster) { + static_assert(cute::size<0>(cluster_shape_mnk) % 2 == 0, "Cluster shape not divisible by MMA size"); + if constexpr (cute::size<0>(cluster_shape_mnk) == 2) { + return cute::SM100_TMA_2SM_LOAD{}; + } + else { + return cute::SM100_TMA_2SM_LOAD_MULTICAST{}; + } + } + else { + return cute::SM100_TMA_2SM_LOAD_MULTICAST{}; + } + } else if constexpr (size(atom_thr_id) == 1) { + if constexpr (!IsDynamicCluster) { + return detail::sm90_cluster_shape_to_tma_atom(cute::size<0>(cluster_shape_mnk)); + } + else { + // In the case of dynamic cluster, multicast decision is not known at compile time. + // A multicast instruction is forced by passing a cute::Int<2>{} to this helper. + return detail::sm90_cluster_shape_to_tma_atom(cute::Int<2>{}); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported Configuration for SM100 TMA"); + } +} + +template +constexpr uint32_t find_vector_size() { + if constexpr (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) { + return 16; + } + else { + return 32; + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAMmaccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + UMMA::ScaleIn ANeg = UMMA::ScaleIn::One, + UMMA::ScaleIn BNeg = UMMA::ScaleIn::One +> +constexpr auto +sm100_make_1sm_trivial_tiled_mma() { + + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 64 || M == 128, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. 
+ constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 8 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v) { + static_assert(cute::is_same_v, "ElementAMma and ElementBMma must match."); + return make_tiled_mma(cute::SM100_MMA_TF32_SS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + static_assert(cute::is_same_v, "ElementAMma and ElementBMma must match."); + return make_tiled_mma(cute::SM100_MMA_F16BF16_SS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_S8_SS{}); + } + else if constexpr (cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + ) { + + return make_tiled_mma( + cute::MMA_Traits< + cute::SM100_MMA_F8F6F4_SS, + ElementAMma, + ElementBMma, + ElementAMmaccumulator, + cute::C, + cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant + >{} + ); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM100 collective builder."); + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAMmaccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + UMMA::ScaleIn ANeg = UMMA::ScaleIn::One, + UMMA::ScaleIn BNeg = UMMA::ScaleIn::One +> +constexpr auto +sm100_make_2sm_trivial_tiled_mma() { + + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + // Do not allow a tiled MMA N mode > 1, as that is not reasonable. + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 8 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v) { + static_assert(cute::is_same_v, "ElementAMma and ElementBMma must match."); + return make_tiled_mma(cute::SM100_MMA_TF32_2x1SM_SS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + static_assert(cute::is_same_v, "ElementAMma and ElementBMma must match."); + return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_SS{}); + } + else if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_S8_2x1SM_SS{}); + } + else if constexpr (cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + ) { + + return make_tiled_mma( + cute::MMA_Traits< + cute::SM100_MMA_F8F6F4_2x1SM_SS, + ElementAMma, + ElementBMma, + ElementAMmaccumulator, + cute::C, + cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant + >{} + ); + + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM100 collective builder."); + } +} + +// For new MMA construction and partitioning that supports both dynamic and static cluster shape. +// Used in conjunction with make_tma_atom_(A|B)_sm100 +// TileShape_MNK is always static and has shape (MmaAtomShapeM, MmaAtomShapeN, TileK) +// ClusterShape_MNK can be dynamic or static. 
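// Editorial example (illustrative only; the operand types are assumptions, not taken
// from this patch): under KernelScheduleAuto, a static cluster whose M extent is even
// combined with a CTA tile whose M is a multiple of 128 selects the 2SM path, e.g.
//
//   using TiledMma2Sm = decltype(sm100_make_trivial_tiled_mma<
//       cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
//       cute::Shape<cute::_256, cute::_128, cute::_128>,   // TileShape_MNK
//       cute::Shape<cute::_2, cute::_1, cute::_1>,         // static cluster, M % 2 == 0
//       cute::UMMA::Major::K, cute::UMMA::Major::K,
//       KernelScheduleAuto>());
//
// while a dynamic cluster shape (runtime int extents) always falls back to the 1SM MMA.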
+template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class KernelScheduleType, + UMMA::ScaleIn ANeg = UMMA::ScaleIn::One, + UMMA::ScaleIn BNeg = UMMA::ScaleIn::One +> +constexpr auto +sm100_make_trivial_tiled_mma() { + // MMA_2SM requested + if constexpr (cute::is_base_of_v ) { + return sm100_make_2sm_trivial_tiled_mma(); + } + // MMA_1SM requested + else if constexpr (cute::is_base_of_v ) { + return sm100_make_1sm_trivial_tiled_mma(); + } + // Auto scheduling requested + else if constexpr (cute::is_same_v) { + // Static cluster + if constexpr (cute::is_static_v) { + // For MMA_2SM we need a cluster shape that is multiple of 2x1 + // and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM + if constexpr (cute::size<0>(ClusterShape_MNK{}) % 2 == 0 && + cute::size<0>(TileShape_MNK{}) % 128 == 0) { + return sm100_make_2sm_trivial_tiled_mma(); + } + else { + return sm100_make_1sm_trivial_tiled_mma(); + } + // Dynamic cluster shape means we cannot assume we can use 2SM MMA + } + else { + return sm100_make_1sm_trivial_tiled_mma(); + } + } +} + + +/** + * @brief Check for U4_UNPACK_U8, U6_UNPACK_U8 alignment requirement + * + * @tparam TileShape_MNK (MmaAtomShape_M, MmaAtomShape_N, TileShape_K) + * @tparam ClusterShape_MNK (cluster_M, cluster_N, cluster_K) + * @tparam KernelScheduleType Builder tag + */ +template< + class ElementAMma, + class ElementBMma, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class KernelScheduleType, + bool Is2sm +> +constexpr bool sm100_gemm_check_for_f8f6f4_mix8bit_requirement(){ + + + [[maybe_unused]] constexpr int TileShape_M = Is2sm ? size<0>(TileShape_MNK{}) / 2 : size<0>(TileShape_MNK{}); + [[maybe_unused]] constexpr int TileShape_N = size<1>(TileShape_MNK{}); + [[maybe_unused]] constexpr int TileShape_K = size<2>(TileShape_MNK{}); + + constexpr bool is_b_unpack_f4_f6 = cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v; + constexpr bool is_a_unpack_f4_f6 = cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v; + + [[maybe_unused]] constexpr bool is_b_n_major = UmmaMajorB == UMMA::Major::MN; + [[maybe_unused]] constexpr bool is_b_k_major = !is_b_n_major; + [[maybe_unused]] constexpr bool is_a_m_major = UmmaMajorA == UMMA::Major::MN; + [[maybe_unused]] constexpr bool is_a_k_major = !is_a_m_major; + + // 2SM + if constexpr (Is2sm) { + constexpr bool valid_a = !is_a_unpack_f4_f6 || (is_a_k_major ? + TileShape_K % 128 == 0 : + TileShape_M % 128 == 0); + + constexpr bool valid_b = !is_b_unpack_f4_f6 || (is_b_n_major ? + TileShape_N % 256 == 0: + TileShape_K % 128 == 0); + return valid_a && valid_b; + } + // 1SM + else { + constexpr bool valid_a = !is_a_unpack_f4_f6 || (is_a_k_major ? + TileShape_K % 128 == 0 : + TileShape_M % 128 == 0); + + constexpr bool valid_b = !is_b_unpack_f4_f6 || (is_b_n_major ? 
+ TileShape_N % 128 == 0 : + TileShape_K % 128 == 0); + + return valid_a && valid_b; + } +} + +template +constexpr bool +sm1xx_gemm_is_aligned() { + // Only support dense gemm alignment check + constexpr bool is_f6f4_subbytes = cute::sizeof_bits_v < 8 || cute::sizeof_bits_v < 8; + + return ((cute::sizeof_bits_v * AlignmentA) % cutlass::detail::get_input_alignment_bits() == 0) && + ((cute::sizeof_bits_v * AlignmentB) % cutlass::detail::get_input_alignment_bits() == 0); +} + +template +constexpr bool +sm1xx_blockscaled_gemm_is_aligned() { + // Only support blocksscaled gemm alignment check + constexpr bool is_f6f4_subbytes = (cute::sizeof_bits_v < 8 || cute::sizeof_bits_v < 8) && + (cute::is_base_of_v + ); + + return ((cute::sizeof_bits_v * AlignmentA) % cutlass::detail::get_input_alignment_bits() == 0) && + ((cute::sizeof_bits_v * AlignmentB) % cutlass::detail::get_input_alignment_bits() == 0); +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl b/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl new file mode 100644 index 0000000000..82ae5395de --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +namespace cutlass::gemm::collective::detail { + +template< + class ClusterShape_MNK, + int AccumulatorPipelineStageCount, + int SchedulerPipelineStageCount, + int CLCResponseSize, + bool IsArrayOfPointersGemm, + int NumTensorMaps=2 +> +struct Sm100DenseGemmTmaUmmaCarveout { + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // LoadOrderBarrier = OrderedSequenceBarrier<1,2> + static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = SchedulerPipelineStageCount * sizeof(uint32_t); + // Tensormap Storage + static constexpr auto TensorMapStorage = + IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * NumTensorMaps /* for A and B */ : + 0; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + LoadOrderBarrierStorage + + TmemDeallocStorage + + CLCThrottlePipelineStorage + + CLCResponseStorage + + TmemBasePtrsStorage + + TensorMapStorage + ); +}; + +template +struct Sm100SparseGemmTmaUmmaCarveout { + + // * GemmUniversal::SharedStorage::PipelineStorage + // LoadOrderBarrier = OrderedSequenceBarrier<1,2> + static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); + // CLCPipelineStorage = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Epilogue Throttle + static constexpr auto EpilogueThrottleStorage = sizeof(arch::ClusterBarrier); + + static constexpr auto PipelineStorage = static_cast(cutlass::round_up( + cutlass::round_up(LoadOrderBarrierStorage, 16) + + cutlass::round_up(CLCPipelineStorage, 16) + + cutlass::round_up(AccumulatorPipelineStorage, 16) + + cutlass::round_up(CLCThrottlePipelineStorage, 16) + + cutlass::round_up(TmemDeallocStorage, 8) + + cutlass::round_up(EpilogueThrottleStorage, 8), + 16)); + + // * GemmUniversal::SharedStorage::Others + // CLC (scheduler) response + static constexpr auto CLCQueryResponseStorage = SchedulerPipelineStageCount * CLCResponseSize; + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + + static constexpr auto OtherStorage = 
static_cast(cutlass::round_up( + cutlass::round_up(CLCQueryResponseStorage, 16) + + cutlass::round_up(TmemBasePtrsStorage, 16), + 16)); + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( PipelineStorage + + OtherStorage); +}; +} // namespace cutlass::gemm::collective::detail diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl new file mode 100644 index 0000000000..c937619c1d --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +// + +// + +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class MainloopPipelineStorage, + int stages +> +constexpr int +sm100_compute_stage_count_or_override(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
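+// Rough worked example for the StageCountAutoCarveout overload further below (numbers are
+// assumptions for illustration only): with 8-bit A/B operands and a 128x256x64 tile, one stage
+// holds 128*64 = 8 KiB of A plus 256*64 = 16 KiB of B plus the mainloop pipeline storage, so a
+// post-carveout budget of roughly 180 KiB yields 7 stages after integer division.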
+template< + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class MainloopPipelineStorage, + int stages +> +constexpr int +sm100_compute_stage_count_or_override(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class MainloopPipelineStorage, + int carveout_bytes> +constexpr int +sm100_compute_stage_count_or_override(StageCountAutoCarveout stage_count) { + // For F8F6F4 sub-bytes, ElementA/B will be passed in as uint8_t + // For Planar Complex, ElementA/B will be passed in as cutlass::complex + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = (CollectiveMma::SharedStorage::PipelineStorage = PipelineTmaUmmaAsync) + constexpr auto mainloop_pipeline_bytes = sizeof(MainloopPipelineStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +check_input_datatypes() { + auto is_non_f4f6f8_input = [&]() { + return (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (cute::is_same_v); // For all MMA instrs except F4F6F8, A and B types should be the same. 
+ }; + auto is_f4f6f8_input = [&]() { + // Allowed input element datatype for narrow precision GEMM + return ( + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) && + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) + ) || + ( + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) && + ( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v + ) + ); + }; + + static_assert(is_f4f6f8_input() || is_non_f4f6f8_input(), "Unsupported data type for ElementA"); + + return true; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + StageCountType, + KernelScheduleType, + cute::enable_if_t< + not cute::is_tuple_v && not cute::is_tuple_v && + not cute::is_complex_v && not cute::is_complex_v && + // Dense Gemm / PtrArrayDenseGemm + ( + (cute::is_base_of_v || + cute::is_same_v)) && + // Alignment check + detail::sm1xx_gemm_is_aligned()>> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(detail::check_input_datatypes(), "Incorrect input types"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + + static constexpr bool is_2sm = cute::is_base_of_v || + (not cute::is_base_of_v && + not cute::is_base_of_v && + cute::is_static_v && + cute::get<0>(ClusterShape_MNK{}) % 2 == 0 ); + + static_assert(detail::sm100_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, KernelScheduleType>()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using AtomThrID = typename TiledMma::AtomThrID; + + using AtomThrShapeMNK = cute::Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(cute::shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, 
make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + // Kludged right divide to divide TileShape_M/N by 1SM/2SM + // Future work: fix partition_shape to account for hierarchies and + // contiguity so we can pass BlockTileA/B to sm100_smem_selector instead + using SmemShape_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(TileShape_MNK{}), shape_div(shape<1>(TileShape_MNK{}), size<1>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(TileShape_MNK{})); + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B( + ClusterShape_MNK{}, AtomThrID{})); + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShape_M, SmemShape_K>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, SmemShape_N, SmemShape_K>()); + static constexpr uint32_t TotalTmemRows = 128; + static constexpr uint32_t Sm100TmemCapacityColumns = 512; + static constexpr uint32_t TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static constexpr uint32_t AccumulatorPipelineStageCount = TotalTmem / (cute::size<0>(CtaTileShape_MNK{}) * cute::size<1>(CtaTileShape_MNK{})); + static_assert(AccumulatorPipelineStageCount > 0, "Accumulator pipeline stage count must be positive. This error probably means that TileShape_MNK and/or TiledMma::ThrLayoutVMNK are wrong."); + + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + using StrideA = cutlass::gemm::TagToStrideA_t; + using InternalStrideA = cute::remove_pointer_t; + // Grouped GEMM (where Stride type is Stride*) does not use CLC based scheduler. + // SchedulerPipelineStageCount could be set to zero for Grouped GEMM, but we shouldn't define CLC Pipeline's barrier arrays of size zero. + static constexpr uint32_t SchedulerPipelineStageCount = cute::is_same_v ? (AccumulatorPipelineStageCount + 1) : 1; + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< + ClusterShape_MNK, + AccumulatorPipelineStageCount, + SchedulerPipelineStageCount, + detail::CLCResponseSize, + IsArrayOfPointersGemm + >::KernelSmemCarveout; + // Reduce SMEM capacity available for buffers considering barrier allocations. 
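+  // The carveout covers the accumulator, CLC, and CLC-throttle pipelines, the load-order barrier,
+  // the CLC response buffers, TMEM base-pointer/dealloc bookkeeping, and (for grouped GEMM)
+  // tensormap staging, as enumerated in Sm100DenseGemmTmaUmmaCarveout.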
+ static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, and B."); + + using DispatchPolicy = + cute::conditional_t, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + > + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 6ec4daca45..c54cf9072a 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -39,4 +39,9 @@ #include "cutlass/gemm/collective/collective_builder_decl.hpp" #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" #include "cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl" +#if !defined(__CUDACC_RTC__) +#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 9d8a1ba2d8..a57e5a082f 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -48,4 +48,10 @@ #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" +#if !defined(__CUDACC_RTC__) +#include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" +#endif // !defined(__CUDACC_RTC__) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp new file mode 100644 index 0000000000..0b0e2e3a27 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -0,0 +1,1268 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class 
TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm100BlkScaledConfig = cutlass::detail::Sm100BlockScaledConfig; + using Blk_MN = typename Sm100BlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. 
+ // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % 
size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = cute::conditional_t; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = uint_bit_t>; + using BitTypeElementB = uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor 
smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFA; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementSF const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB 
const** ptr_B; + StrideB dB; + ElementSF const** ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + constexpr int tma_alignment_bits = 128; + auto init_M = tma_alignment_bits; + auto init_N = tma_alignment_bits; + auto init_K = tma_alignment_bits; + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + InternalLayoutSFA layout_SFA; + InternalLayoutSFB layout_SFB; + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + layout_SFA = args.layout_SFA; + layout_SFB = args.layout_SFB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. 
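+    // The null pointers and placeholder shapes below only seed construction of the initial TMA
+    // descriptors; the per-batch/per-group descriptors are rewritten through the tensormap
+    // workspace before the first TMA load is issued.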
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = nullptr; + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + args.layout_SFA, + reinterpret_cast(args.ptr_SFB), + args.layout_SFB, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments 
const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + + template + CUTLASS_DEVICE auto + slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tAgSFA_mkl - partitioned gmem tensor for SFA + /// tBgSFB_nkl - partitioned gmem tensor for SFB + /// tAsSFA - partitioned tmem tensor for SFA + /// tAsSFB - partitioned tmem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Represent the full tensor of Scale factors + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + if constexpr (IsGroupedGemmKernel) { + layout_SFA = params.layout_SFA[init_group]; + layout_SFB = params.layout_SFB[init_group]; + } + else { + layout_SFA = params.layout_SFA; + layout_SFB = params.layout_SFB; + } + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto x = stride<0,2>(mSFB_tmp); + auto y = ceil_div(shape<0,2>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), shape<0,1>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), stride<0,1>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // 
(MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. 
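+  /// Returned tuple contains the tiled MMA (with runtime data-type descriptors applied when
+  /// applicable), the SMEM descriptor fragments for A and B, the TMEM-resident tensors for SFA
+  /// and SFB, and the SMEM->TMEM (UTCCP) copy atoms plus their partitioned source/destination
+  /// views used to stage the scale factors.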
+ template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::Tensor const& accumulators, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + // Set tCtSFA and tCtSFB start addresses. Only update the TMEM column address by masking the address with 0x000001FF. + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, 
thr_tCtSFB_compact_s2t); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapA, class TensorMapB, + class TensorMapSFA, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. 
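+      // Note: a single producer barrier is presumably armed with the combined transaction byte
+      // count for A, B, SFA and SFB, so the four TMA copies issued below all signal the same
+      // mbarrier and the stage only becomes "full" once every operand (including the scale
+      // factors) has landed in SMEM.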
+ + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto [tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + if (k_tile_count > 0) { // first iteraion + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + 
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + if constexpr (IsOverlappingAccum) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = 
make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + Tensor pSFA_tensormap = make_tensor(observed_tma_load_sfa_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{}); + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + + copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_sfa, tma_desc_sfb); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + mainloop_params.ptr_SFA[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_SFA = {1,1,1,1,1}; + cute::array prob_stride_SFA = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + ElementSF const* ptr_SF = nullptr; + Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]); + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfa_, tensor_sfa, + prob_shape_SFA, 
prob_stride_SFA); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v<TmaInternalElementA>) / 8; + } + for (uint64_t& stride : prob_stride_SFA) { + stride = (stride * sizeof_bits_v<ElementSF>) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v<TmaInternalElementB>) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v<ElementSF>) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + prob_shape_SFA, + prob_stride_SFA); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template <class TensorMapA, class TensorMapB, class TensorMapSFA, class TensorMapSFB, class ProblemShape> + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple<TensorMapA, TensorMapB, TensorMapSFA, TensorMapSFB> const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (i.e., it's aligned) + tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + } + + template <class TensorMapA, class TensorMapB, class TensorMapSFA, class TensorMapSFB> + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple<TensorMapA, TensorMapB, TensorMapSFA, TensorMapSFB> const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e.
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_SFA); + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + +private: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp new file mode 100644 index 0000000000..b28a30754d --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -0,0 +1,1092 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm100BlkScaledConfig = cutlass::detail::Sm100BlockScaledConfig; + using Blk_MN = typename Sm100BlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 
= shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + 
static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = cute::conditional_t; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
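+  // TensorStorage bundles the per-stage SMEM staging buffers for A, B, SFA and SFB (sized by
+  // their pipelined SMEM layouts and 128B aligned), while PipelineStorage holds the mainloop
+  // pipeline's barrier state; keeping the two aliases separate lets the kernel layer place
+  // them independently.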
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + template< + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + STensorA tAsA; + STensorB tBsB; + // for scale factor tensor values + GTensorPartitionedSFA tAgSFA_mkl; + GTensorPartitionedSFB tBgSFB_nkl; + STensorSFA tAsSFA; + STensorSFB tBsSFB; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + uint16_t mcast_mask_sfa; + uint16_t mcast_mask_sfb; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + GTensorPartitionedSFA tAgSFA_mkl_, GTensorPartitionedSFB tBgSFB_nkl_, + STensorSFA tAsSFA_, STensorSFB tBsSFB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_, + uint16_t mcast_mask_sfa_, uint16_t mcast_mask_sfb_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , tAgSFA_mkl(tAgSFA_mkl_), tBgSFB_nkl(tBgSFB_nkl_) + , tAsSFA(tAsSFA_), tBsSFB(tBsSFB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) + , mcast_mask_sfa(mcast_mask_sfa_), mcast_mask_sfb(mcast_mask_sfb_) {} + }; + + template< + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + struct MmaParams { + TiledMma tiled_mma; + FragmentA tCrA; + FragmentB tCrB; + FragmentSFA tCtSFA; + FragmentSFB tCtSFB; + SFATiledCopy tiled_copy_s2t_SFA; + SmemFrgSFA thr_tCsSFA_s2t; + TmemFrgSFA thr_tCtSFA_s2t; + SFBTiledCopy tiled_copy_s2t_SFB; + SmemFrgSFB thr_tCsSFB_s2t; + TmemFrgSFB thr_tCtSFB_s2t; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_, FragmentSFA tCtSFA_, FragmentSFB tCtSFB_, + SFATiledCopy tiled_copy_s2t_SFA_, SmemFrgSFA thr_tCsSFA_s2t_, TmemFrgSFA thr_tCtSFA_s2t_, + SFBTiledCopy tiled_copy_s2t_SFB_, SmemFrgSFB thr_tCsSFB_s2t_, TmemFrgSFB thr_tCtSFB_s2t_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_), tCtSFA(tCtSFA_), tCtSFB(tCtSFB_) + , tiled_copy_s2t_SFA(tiled_copy_s2t_SFA_), thr_tCsSFA_s2t(thr_tCsSFA_s2t_), thr_tCtSFA_s2t(thr_tCtSFA_s2t_) + , tiled_copy_s2t_SFB(tiled_copy_s2t_SFB_), thr_tCsSFB_s2t(thr_tCsSFB_s2t_), thr_tCtSFB_s2t(thr_tCtSFB_s2t_) {} + }; + + // Host side kernel arguments 
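+  // Illustrative host-side sketch (the pointer/stride variable names below are placeholders,
+  // not part of this collective):
+  //
+  //   Arguments args;
+  //   args.ptr_A = d_A;  args.dA = stride_A;
+  //   args.ptr_B = d_B;  args.dB = stride_B;
+  //   args.ptr_SFA = d_SFA;
+  //   args.layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(append<4>(problem_shape, 1));
+  //   args.ptr_SFB = d_SFB;
+  //   args.layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(append<4>(problem_shape, 1));
+  //
+  // can_implement() below rejects any layout_SFA/layout_SFB that does not match these
+  // reference layouts.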
+ struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? 
¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, 
+ tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + // Check for SFA SFB layout requirement + const auto layout_sfa_ref = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_sfb_ref = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + implementable = implementable && (layout_sfa_ref == args.layout_SFA); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); + } + + implementable = implementable && (layout_sfb_ref == args.layout_SFB); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. 
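+    // TMEM placement (see set_tmem_offsets() below): the accumulators occupy the leading TMEM
+    // columns, tCtSFA begins at the column immediately after them, and tCtSFB follows tCtSFA;
+    // only the column component of the TMEM address differs between the three tensors.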
+ Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tAgSFA_mkl - partitioned gmem tensor for SFA + /// tBgSFB_nkl - partitioned gmem tensor for SFB + /// tAsSFA - partitioned tmem tensor for SFA + /// tAsSFB - partitioned tmem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,2>(mSFB_tmp); + auto y = ceil_div(shape<0,2>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), shape<0,1>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), stride<0,1>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + 
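+    // For the 2-CTA MMA atoms, AtomThrID has size 2 and blockIdx.x % 2 selects this CTA's slot
+    // within the atom; the partition_A/partition_B calls below then give each peer CTA the
+    // slice of the MMA-tiled gmem views it is responsible for loading.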
Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + LoadParams load_params { + size<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb // multicast masks + }; + return load_params; + } + + /// Set up the data needed by this collective for mma compute. 
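+  /// Returns an MmaParams bundle: the TiledMma (with runtime A/B formats applied when
+  /// IsRuntimeDataType), SMEM-descriptor fragments tCrA/tCrB, the TMEM scale-factor tensors
+  /// tCtSFA/tCtSFB, and the UTCCP SMEM->TMEM tiled copies together with their partitioned
+  /// source (SMEM descriptor) and destination (TMEM) views used by mma().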
+ template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + MmaParams< + decltype(tCrA), decltype(tCrB), decltype(tCtSFA), decltype(tCtSFB), + decltype(tiled_copy_s2t_SFA), decltype(thr_tCsSFA_compact_s2t), decltype(thr_tCtSFA_compact_s2t), + decltype(tiled_copy_s2t_SFB), decltype(thr_tCsSFB_compact_s2t), decltype(thr_tCtSFB_compact_s2t) + > mma_params { + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t + }; + return mma_params; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. 
+ + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if (k_tile_count > 0) { // first iteraion + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = 
mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + if constexpr (IsOverlappingAccum) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + +private: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp new file mode 100644 index 0000000000..b652c89e27 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp @@ -0,0 +1,864 
@@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + 
AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide 
tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = uint_bit_t>; + using BitTypeElementB = uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. 
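As the comment above notes, for grouped GEMM these descriptors are built with placeholder shapes and pointers and patched per batch before the first load. To allow that patching without cross-SM races, each SM keeps its own global-memory copy of the A and B tensormaps: the workspace sized in get_workspace_size further below holds all A descriptor copies followed by all B descriptor copies, and tensormaps_init picks out the pair belonging to the current SM. A sketch of that sizing and indexing, with a hypothetical descriptor type, SM count, and SM index:

#include <cstddef>
#include <cstdio>

// Hypothetical opaque 128-byte descriptor standing in for cute::TmaDescriptor.
struct alignas(128) FakeTmaDescriptor { unsigned char bytes[128]; };

int main() {
  constexpr int num_input_tensors = 2;  // A and B
  constexpr int sm_count = 132;         // hypothetical SM count
  constexpr int sm_idx = 7;             // hypothetical SM index

  // Total workspace: one A copy and one B copy per SM.
  constexpr size_t workspace_bytes =
      size_t(num_input_tensors) * sizeof(FakeTmaDescriptor) * sm_count;

  // Layout: [A copy for SM 0 .. SM sm_count-1][B copy for SM 0 .. SM sm_count-1]
  constexpr size_t a_slot = sm_idx;             // &workspace[sm_idx]
  constexpr size_t b_slot = sm_idx + sm_count;  // &workspace[sm_idx + sm_count]

  std::printf("workspace: %zu bytes, A descriptor slot %zu, B descriptor slot %zu\n",
              workspace_bytes, a_slot, b_slot);
  return 0;
}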
+ auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // Tensor pointers will be fixed before the first access + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static 
bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE auto + slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = 
make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::Tensor const& accumulators, + TensorStorage& shared_tensors, + [[maybe_unused]] uint32_t const tmem_nonaccum_offset) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
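The assignments that follow implement that comment: the runtime data-type enum is masked down to its low three bits before being stored into the instruction descriptor's narrow format bit-fields, so the compiler can see the value fits without a narrowing conversion. A sketch of the same pattern with a hypothetical enum and descriptor layout:

#include <cstdint>
#include <cstdio>

// Hypothetical runtime data-type enum; the enumerator values are illustrative only.
enum class RuntimeFormat : uint8_t { kFormat0 = 0, kFormat1 = 1, kFormat5 = 5 };

// Hypothetical instruction descriptor with 3-bit operand-format fields.
struct InstrDescriptor {
  uint8_t a_format_ : 3;
  uint8_t b_format_ : 3;
  uint8_t reserved_ : 2;
};

int main() {
  InstrDescriptor idesc{};
  RuntimeFormat a = RuntimeFormat::kFormat5;
  RuntimeFormat b = RuntimeFormat::kFormat1;

  // Masking with 0b111 keeps only the three bits each field can hold.
  idesc.a_format_ = uint8_t(a) & 0b111;
  idesc.b_format_ = uint8_t(b) & 0b111;

  std::printf("a_format=%u b_format=%u\n", unsigned(idesc.a_format_), unsigned(idesc.b_format_));
  return 0;
}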
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator 
must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + 
TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +private: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp new file mode 100644 index 0000000000..c7e562507d --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -0,0 +1,723 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and 
ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
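The LoadParams and MmaParams structs declared a few lines below bundle everything load_init and mma_init produce into single objects, so the kernel layer can hand one handle across the warp-specialization boundary; the producer and consumer then unpack them with structured bindings, as in the structured-binding unpack of load_inputs seen later in load(). A minimal sketch of that idiom with hypothetical members:

#include <cstdint>
#include <cstdio>

// Hypothetical bundle mirroring the shape of LoadParams: a k-tile count,
// a couple of tensor handles, and the TMA multicast masks, returned as one object.
struct ExampleLoadParams {
  int k_tiles;
  const float* tAgA;
  const float* tBsB;
  uint16_t mcast_mask_a;
  uint16_t mcast_mask_b;
};

ExampleLoadParams example_load_init() {
  static const float dummy[4] = {0.f, 1.f, 2.f, 3.f};
  return {8, dummy, dummy, 0x3, 0x5};
}

int main() {
  // Structured bindings unpack the public members in declaration order.
  auto [k_tiles, tAgA, tBsB, mask_a, mask_b] = example_load_init();
  std::printf("k_tiles=%d mask_a=0x%x mask_b=0x%x first=%f\n",
              k_tiles, unsigned(mask_a), unsigned(mask_b), double(tAgA[0]));
  (void)tBsB;
  return 0;
}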
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + template< + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + STensorA tAsA; + STensorB tBsB; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {} + }; + + template + struct MmaParams { + TiledMma tiled_mma; + FragmentA tCrA; + FragmentB tCrB; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_) {} + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + 
cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), 
make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + LoadParams load_params { + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + }; + return load_params; + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_tensors, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + MmaParams mma_params { + tiled_mma, + tCrA, tCrB + }; + return mma_params; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early 
exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + +private: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 51ec28c7d0..9d3bb1ada4 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ 
b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -388,6 +388,17 @@ class GemmUniversalAdapter< [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + ) { + if constexpr (!cute::is_static_v) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + [[maybe_unused]] void* kernel_params[] = {¶ms}; if constexpr (kEnableCudaHostAdapter) { @@ -415,6 +426,7 @@ class GemmUniversalAdapter< else { launch_result = cuda_adapter->launch(grid, cluster, + fallback_cluster, block, smem_size, stream, @@ -430,8 +442,7 @@ class GemmUniversalAdapter< else { CUTLASS_ASSERT(cuda_adapter == nullptr); [[maybe_unused]] void const* kernel = (void const*) device_kernel; - static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90 - ; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90; if constexpr (kClusterLaunch) { if constexpr (is_static_1x1x1) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) @@ -456,6 +467,42 @@ class GemmUniversalAdapter< grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); } } + + else { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + ) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with fall-back cluster"); +#endif + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params, + launch_with_pdl); + } + } + } + } } else { diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 6c98624367..62381192a7 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -35,6 +35,8 @@ #include "cute/layout.hpp" #include "cute/numeric/integral_constant.hpp" // cute::false_type +#include "cute/arch/copy_sm100.hpp" + ////////////////////////////////////////////////////////////////////////////// namespace cutlass::detail { @@ -346,6 +348,174 @@ struct MainloopSm90TmaGmmaWarpSpecializedSparseFP8 }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + +// Gemm with block scaling factors +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedBlockScaledSm100 final { + static constexpr int SchedulerPipelineStageCount = 
SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + + + +// Ptr-Array Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelPtrArrayTmaWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + +// Ptr-Array Block Scaled GEMM +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelPtrArrayTmaWarpSpecializedBlockScaledSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + + + +////////////////////////////////////////////////////////////////////////////// + +// +// Collective Builder Tag Property +// + +struct KernelSchedule1Sm {}; +struct KernelSchedule2Sm {}; +struct KernelScheduleSm100 {}; +struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; + +struct KernelScheduleBlockScaledGemmSm100 : KernelScheduleSm100 {}; +struct KernelScheduleMxNvf4Sm100 : KernelScheduleBlockScaledGemmSm100 {}; +struct KernelScheduleMxf8f6f4Sm100 : KernelScheduleBlockScaledGemmSm100 {}; + +struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {}; +struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {}; +struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; +struct KernelSchedulePtrArrayMxf8f6f4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; + + +// +// Collective Builder Tag +// Only used in CollectiveBuilder +// + +// Dense GEMM: Specialize for 1SM vs 2SM +struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; +struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; + + + +// Block Scaled Dense GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 2SM +struct KernelTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 { }; +struct KernelTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelScheduleBlockScaledGemmSm100 { }; +struct KernelTmaWarpSpecialized1SmNvf4Sm100 final : KernelSchedule1Sm, KernelScheduleMxNvf4Sm100 { }; +struct KernelTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; +struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelScheduleMxNvf4Sm100 { }; +struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; +struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { }; +struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; + + +// Ptr-Array Dense GEMM: Specialize for 1SM vs 2SM +struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {}; +struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {}; + + +// Ptr-Array Block Scaled Dense GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 
2SM +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayBlockScaledGemmSm100 { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayBlockScaledGemmSm100 { }; +struct KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxNvf4Sm100 { }; +struct KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxNvf4Sm100 { }; +struct KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxNvf4Sm100 { }; +struct KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxNvf4Sm100 { }; +struct KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { }; +struct KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { }; + + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100TmaUmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; + + + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100TmaUmmaWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelTmaWarpSpecializedBlockScaledSm100; +}; + + + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100ArrayTmaUmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = false; + using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; +}; + +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm100; +}; + + + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 2b54758d9c..50245571db 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -63,4 +63,6 @@ struct 
IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cute::disjunction_v< + cutlass::detail::is_kernel_tag_of, + cutlass::detail::is_kernel_tag_of>>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using LayoutSFA = typename cutlass::detail::LayoutSFAType::type; + using LayoutSFB = typename cutlass::detail::LayoutSFBType::type; + using ElementSF = typename cutlass::detail::ElementSFType::type; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + using TileSchedulerTag = cute::conditional_t; + + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount, ProblemShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool 
IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorMapStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + 
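+    // Per-CTA staging buffers for TMA descriptors; each member below sits in its own 128B-aligned slot
+    // so the kernel can rewrite tensormaps in place when the active batch/group changes (grouped / ptr-array GEMM).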
alignas(128) EpilogueTensorMapStorage epilogue; + alignas(128) MainloopTensorMapStorage mainloop; + } tensormaps; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + constexpr uint32_t NumEpilogueSubTiles = 1; + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + ProblemShape problem_shapes = args.problem_shape; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (!IsGroupedGemmKernel && sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), 
TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + scheduler, + args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Dynamic Cluster or Preferred Cluster don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + constexpr bool IsBlockscaled = !cute::is_void_v; + if constexpr (IsBlockscaled) { + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + } + else { + // Special cluster check for scale factor multicasts. 
Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + } + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Mainloop + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + constexpr uint32_t NumEpilogueSubTiles = 1; + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Mainloop + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // NOTE: cluster_shape here is the major cluster shape, not fallback one + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + grid_shape = TileScheduler::get_grid_shape( + 
params.scheduler, + params.problem_shape.get_host_problem_shape(), + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + auto problem_shape = params.problem_shape; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : WarpCategory::Epilogue; + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue) // epilogue + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + 
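+    // Expected TMA transaction byte count per epilogue-load stage; the stage's full barrier
+    // completes only after this many bytes have been delivered by TMA.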
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
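+    // Consumers are the epilogue threads of every CTA in the MMA atom (1 or 2 CTAs), matching the arrive counts set below.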
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if constexpr(!IsOverlappingAccum) { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (WarpCategory::MMA == warp_category && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } + } + + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + (is_epi_load_needed ? 
NumEpilogueLoadThreads : 0)); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + int32_t sm_id = static_cast(cutlass::arch::SmId()); + + // Calculate mask after cluster barrier arrival + mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); + accumulator_pipeline.init_masks(cluster_shape); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + TiledMma tiled_mma; + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + + pipeline_init_wait(cluster_size); + + if constexpr (IsGroupedGemmKernel) { + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + // In case user wants to engage less SMs than available on device + sm_id = blockIdx.x + (blockIdx.y * gridDim.x); + } + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + + if (is_participant.main_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, + shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id, work_tile_info.L_idx); + Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. 
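+      // Treat the first tile as a batch change so the mainloop tensormaps are updated before any TMA loads are issued.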
+ bool did_batch_change = true; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape, + curr_batch + ); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue, + did_batch_change + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter_next, k_tile_count - k_tile_prologue, + false /* did_batch_change - prologue loads handle tensormap acquire */ + ); + mainloop_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + + else if (is_participant.sched) { + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + // Grouped GEMM uses static tile scheduler + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query 
must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + + + auto mma_inputs = collective_mainloop.mma_init( + params.mainloop, + collective_mainloop.slice_accumulator(accumulators, 0), + shared_storage.tensors.mainloop, + tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if constexpr (!IsOverlappingAccum) { + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + } + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + int acc_stage = (IsOverlappingAccum) ? 
(accumulator_pipe_producer_state.phase() ^ 1) : (accumulator_pipe_producer_state.index()); + auto accumulator = collective_mainloop.slice_accumulator(accumulators, acc_stage); + if (is_mma_leader_cta) { + mainloop_pipe_consumer_state = collective_mainloop.mma( + cute::make_tuple( + mainloop_pipeline, accumulator_pipeline), + cute::make_tuple( + mainloop_pipe_consumer_state, accumulator_pipe_producer_state), + accumulator, + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + ++accumulator_pipe_producer_state; + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + if constexpr (!IsOverlappingAccum) { + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + } + else { + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + + // Fetch a copy of tensormaps for the CTA from Params + auto epi_load_tensormap = get<0>(collective_epilogue.load_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. 
+ bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape, + curr_batch + ); + } + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.template load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + cute::make_tuple(epi_load_tensormap, did_batch_change), + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + + auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); + bool do_tail_store = false; + // Fetch a copy of tensormaps for the CTA from Params + auto epi_store_tensormap = get<0>(collective_epilogue.store_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. 
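+      // Likewise for the store path: start with did_batch_change = true so the store tensormap is refreshed before the first epilogue.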
+ bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape, + curr_batch + ); + } + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Accumulator stage slice after making sure allocation has been performed + int acc_stage = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_consumer_state.phase(); + } + else { + return accumulator_pipe_consumer_state.index(); + } + }(); + + // Fusions may need problem shape for the current group + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + // + // Epilogue and write to gD + // + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + collective_mainloop.slice_accumulator(accumulators, acc_stage), + shared_storage.tensors.epilogue, + cute::make_tuple(epi_store_tensormap, did_batch_change) + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + + do_tail_store |= TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + if constexpr (IsOverlappingAccum) { + // Signal to peer MMA that Full TMEM alloc can be deallocated + if constexpr (has_mma_peer_cta) { + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank); + } + tmem_deallocation_result_barrier.arrive(); + } + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000000..5d03f92146 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp @@ -0,0 +1,1001 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/mainloop_fusion_helper_scale_factor.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cute::disjunction_v, + cutlass::detail::is_kernel_tag_of>>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using 
ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using LayoutSFA = typename cutlass::detail::LayoutSFAType::type; + using LayoutSFB = typename cutlass::detail::LayoutSFBType::type; + using ElementSF = typename cutlass::detail::ElementSFType::type; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_load_pipe_increment(CtaShape_MNK{}); + + // Fixup performed for split-/stream-K is done across warps in different CTAs + // at epilogue subtile granularity. Thus, there must be one barrier per sub-tile per + // epilogue warp. 
+ static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4 + }; + + struct 
IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + constexpr bool IsBlockscaled = !cute::is_void_v; + if constexpr (IsBlockscaled) { + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster shape check for scale factor multicasts. 
Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + } + else { + // Special cluster shape check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + } + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // NOTE cluster_shape here is the major cluster shape, not fallback one + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto [M,N,K,L] = problem_shape_MNKL; + + // Account for more than one epilogue warp + int warp_idx = 
canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : WarpCategory::Epilogue; + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + collective_mainloop.prefetch_tma_descriptors(); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + collective_epilogue.prefetch_tma_descriptors(params.epilogue); + } + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue) // epilogue + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, 
epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
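To make the arrival counts that follow concrete, here is a throwaway host-side calculation under assumed sizes (128 epilogue threads and a 2-SM MMA atom are assumptions for this sketch, not values taken from this file): a single elected MMA thread signals the "full" side of the accumulator barrier, while every epilogue thread in both peer CTAs must signal the "empty" side before a tmem stage can be reused.

#include <cstdio>

int main() {
  // Assumed values for illustration only; the real numbers come from the
  // collective mainloop/epilogue types selected at compile time.
  constexpr int num_epilogue_threads = 128;  // e.g. 4 epilogue warps
  constexpr int atom_thr_shape_m     = 2;    // 2-SM (peer-CTA) UMMA atom

  // One elected thread in the MMA warp signals "accumulator full" ...
  constexpr int producer_arv_count = 1;
  // ... and every epilogue thread in both peer CTAs signals "empty".
  constexpr int consumer_arv_count = atom_thr_shape_m * num_epilogue_threads;

  std::printf("producer arrivals: %d, consumer arrivals: %d\n",
              producer_arv_count, consumer_arv_count);
  return 0;
}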
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if constexpr(!IsOverlappingAccum) { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (WarpCategory::MMA == warp_category && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } + } + + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + (is_epi_load_needed ? 
NumEpilogueLoadThreads : 0)); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, shared_storage.tensors.mainloop); + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // + // TMEM "Allocation" + // + auto tmem_storage = collective_mainloop.template init_tmem_tensors(EpilogueTile{}); + + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
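The load schedule issued below splits each work tile's K loop into a pipeline-filling prologue followed by the remaining tiles, with the load-order barrier arrival sandwiched between the two collective_mainloop.load() calls. A minimal sketch of that split, with a made-up stage count and tile count:

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed numbers purely for illustration.
  constexpr int pipeline_stages = 4;   // SMEM staging depth of the mainloop
  int k_tile_count = 10;               // K tiles assigned to this work unit

  // Prologue: fill every free pipeline stage before anything is consumed.
  int k_tile_prologue = std::min(pipeline_stages, k_tile_count);
  std::printf("prologue loads: %d\n", k_tile_prologue);

  // Steady state: the remaining loads are issued as the MMA warp drains
  // stages; in the kernel this is the second collective_mainloop.load() call.
  int k_tile_remaining = k_tile_count - k_tile_prologue;
  std::printf("steady-state loads: %d\n", k_tile_remaining);
  return 0;
}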
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, load_inputs.k_tiles); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + + else if (is_participant.sched) { + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + if constexpr (IsSchedDynamicPersistent) { + + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. 
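A toy host-side model of that throttling decision (WorkStep and the sample sequence are inventions for this sketch; consumed_response plays the role of increment_pipe):

#include <cstdio>
#include <vector>

// 'consumed_response' models whether fetch_next_work consumed a CLC query
// response or merely updated the current (e.g. stream-K) work tile in place.
struct WorkStep { bool valid; bool consumed_response; };

int main() {
  std::vector<WorkStep> steps = {
      {true, true}, {true, false}, {true, false}, {true, true}, {false, true}};

  bool requires_clc_query = true;
  int queries_issued = 0;

  for (WorkStep step : steps) {
    if (requires_clc_query) {
      ++queries_issued;                       // advance_to_next_work()
    }
    bool increment_pipe = step.consumed_response;
    requires_clc_query = increment_pipe;      // only re-query after consuming
    if (!step.valid) break;                   // no more work
  }
  std::printf("issued %d CLC queries for %zu scheduler iterations\n",
              queries_issued, steps.size());
  return 0;
}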
+ requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + + } + } + + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + auto mma_inputs = collective_mainloop.mma_init( + tmem_storage, + shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if constexpr (!IsOverlappingAccum) { + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + + // Accumulator stage slice + int acc_stage = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_producer_state.phase() ^ 1; + } + else { + return accumulator_pipe_producer_state.index(); + } + }(); + + if (is_mma_leader_cta) { + mainloop_pipe_consumer_state = collective_mainloop.mma( + cute::make_tuple(mainloop_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state), + collective_mainloop.slice_accumulator(tmem_storage, acc_stage), + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. 
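As far as I can tell, launch_dependent_grids() and wait_on_dependent_grids() are thin wrappers over CUDA programmatic dependent launch; treating that mapping as an assumption, a minimal standalone sketch of the producer/consumer pairing looks like the following (it needs to be compiled for sm_90 or newer, and the consumer must be launched with the programmatic dependent launch attribute for the early launch to actually happen):

#include <cuda_runtime.h>

// Producer: finishes its global-memory writes, then hints that a dependent
// grid may begin launching while this one is still running.
__global__ void producer(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { buf[i] = static_cast<float>(i); }
  // Roughly what cutlass::arch::launch_dependent_grids() hints at:
  cudaTriggerProgrammaticLaunchCompletion();
}

// Consumer: may have been launched early; it must not read the producer's
// output until the dependency resolves (cf. wait_on_dependent_grids()).
__global__ void consumer(const float* buf, float* out, int n) {
  cudaGridDependencySynchronize();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = buf[i] * 2.0f; }
}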
+ cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + if constexpr (!IsOverlappingAccum) { + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + } + else { + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.template load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
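A tiny illustration of the split-K case mentioned above, with made-up split counts: only the worker that owns the final split computes the epilogue, so only it ever sets do_tail_load inside the work loop.

#include <cstdio>

int main() {
  // Illustrative split-K assignment: 3 splits of one output tile handed to
  // 3 different workers. Only the final split writes the epilogue, so only
  // that worker needs to drain the epilogue load pipeline afterwards.
  constexpr int num_splits = 3;
  for (int worker = 0; worker < num_splits; ++worker) {
    bool computes_epilogue = (worker == num_splits - 1);  // final split only
    bool do_tail_load = computes_epilogue;                // set in the work loop
    std::printf("worker %d: epilogue = %s, tail load = %s\n", worker,
                computes_epilogue ? "yes" : "no", do_tail_load ? "yes" : "no");
  }
  return 0;
}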
+ if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Accumulator stage slice + int acc_stage = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_consumer_state.phase(); + } + else { + return accumulator_pipe_consumer_state.index(); + } + }(); + + auto accumulator = get<0>(collective_mainloop.slice_accumulator(tmem_storage, acc_stage)); + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulator, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulator, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + + do_tail_store = true; + } + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + if constexpr (IsOverlappingAccum) { + // Signal to peer MMA that Full TMEM alloc can be deallocated + if constexpr (has_mma_peer_cta) { + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank); + } + tmem_deallocation_result_barrier.arrive(); + } + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp new file mode 100755 index 0000000000..92ce8839b8 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp @@ -0,0 +1,723 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cute/int_tuple.hpp" + +#include "cutlass/arch/config.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm_coord.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/conv/detail.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel::detail { + +//////////////////// Blackwell Scheduler ///////////////////////// + +template< + class ClusterShape_, + uint32_t Stages_ +> +class PersistentTileSchedulerSm100 { + +private: + + using UnderlyingTileScheduler = PersistentTileSchedulerSm90; + +public: + + using ClusterShape = ClusterShape_; + using RasterOrder = UnderlyingTileScheduler::RasterOrder; + using RasterOrderOptions = UnderlyingTileScheduler::RasterOrderOptions; + static constexpr bool IsDynamicPersistent = true; + + static constexpr uint32_t Stages = Stages_; + + // CLC response is an opaque 16B value + struct CLCResponse { uint32_t data[4]; }; + + using WorkTileInfo = typename PersistentTileSchedulerSm90::WorkTileInfo; + + using Params = PersistentTileSchedulerSm100Params; + + using Pipeline = PipelineCLCFetchAsync; + using PipelineStorage = typename Pipeline::SharedStorage; + + using ThrottlePipeline = PipelineAsync; + using ThrottlePipelineStorage = typename ThrottlePipeline::SharedStorage; + + class SharedStorage { + public: + + CUTLASS_DEVICE PipelineStorage& pipeline() { return pipeline_; } + CUTLASS_DEVICE ThrottlePipelineStorage& throttle_pipeline() { return throttle_pipeline_; } + CUTLASS_DEVICE CLCResponse* data() { return data_; } + + private: + alignas(16) PipelineStorage 
pipeline_; + alignas(16) ThrottlePipelineStorage throttle_pipeline_; + alignas(16) CLCResponse data_[Stages]; + }; + + struct Arguments { + + Arguments() = default; + Arguments(Arguments const&) = default; + Arguments(Arguments&&) = default; + + CUTLASS_HOST_DEVICE + Arguments& + operator=(Arguments const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + Arguments& + operator=(Arguments &&) { + return *this; + } + + int max_swizzle_size = 1; + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; + }; + + // + // Static Host Methods + // + + template + static Params + to_underlying_arguments( + ProblemShapeMNKL problem_shape_mnkl, + TileShape tile_shape, + [[maybe_unused]] ClusterShape cluster_shape, + [[maybe_unused]] KernelHardwareInfo const& hw_info, + [[maybe_unused]] Arguments const& args, + [[maybe_unused]] void* workspace = nullptr, + [[maybe_unused]] uint32_t NumEpilogueSubTiles = 1, + [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u + ) { + + auto cs = cutlass::detail::select_cluster_shape(ClusterShape_{}, hw_info.cluster_shape); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cs); + + Params params; + params.initialize( + problem_blocks, + to_gemm_coord(cs), + hw_info, + args.max_swizzle_size, + args.raster_order + ); + return params; + } + + template + static Params + to_underlying_arguments( + ProblemShapeMNKL problem_shape_mnkl, + TileShape tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace = nullptr + ) { + + auto selected_cluster_shape = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, + atom_thr_shape_mnk, selected_cluster_shape); + + Params params; + params.initialize( + problem_blocks, + to_gemm_coord(selected_cluster_shape), + hw_info, + args.max_swizzle_size, + args.raster_order + ); + return params; + } + + // Conv Specialization + template + static Params + to_underlying_arguments( + cutlass::conv::ConvProblemShape problem_shape, + TileShape tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace = nullptr + ) { + + auto problem_shape_mnkl = [&] () { + // Infer im2col linearization from ConvOp and TileShape + constexpr bool is_linearized_M = (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) + && depth<0>(TileShape{}) == _0{}; + constexpr bool is_linearized_K = ConvOp == conv::Operator::kWgrad && depth<2>(TileShape{}) == _0{}; + + if constexpr (is_linearized_M || is_linearized_K) { + // transformation + im2col linearization + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + }(); + + return to_underlying_arguments( + problem_shape_mnkl, + tile_shape_mnk, + atom_thr_shape_mnk, + cluster_shape_mnk, + hw_info, + args, + workspace + ); + } + + // Given the inputs, computes the physical grid we should launch. 
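A plain-integer model of the AlongN transpose performed by possibly_transpose_grid() a little further down may help; the cluster and grid sizes here are made up for the example.

#include <cstdio>

// The CLC scheduler hands out tiles in AlongM order by default, so for
// AlongN rasterization the x/y extents are swapped cluster by cluster.
int main() {
  int cluster_m = 2, cluster_n = 1;
  int grid_x = 8, grid_y = 6;       // CTAs along M and N, pre-transpose

  int new_x = (grid_y / cluster_n) * cluster_m;  // clusters along N -> CTAs along x
  int new_y = (grid_x / cluster_m) * cluster_n;  // clusters along M -> CTAs along y

  std::printf("transposed grid: x=%d y=%d\n", new_x, new_y);  // x=12 y=4
  return 0;
}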
+ template + CUTLASS_HOST_DEVICE + static dim3 + get_grid_shape( + Params const& params, + ProblemShapeMNKL problem_shape_mnk, + BlockShape cta_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + [[maybe_unused]] Arguments arguments) { + auto problem_shape_MNKL = append<4>(problem_shape_mnk, Int<1>{}); + auto grid = get_tiled_cta_shape_mnl(problem_shape_MNKL, cta_shape, cluster_shape); + return possibly_transpose_grid(params.raster_order_, params.divmod_cluster_shape_m_, params.divmod_cluster_shape_n_, grid); + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE + static dim3 + get_grid_shape( + Params const& params, + ProblemShapeMNKL problem_shape_mnkl, + TileShape tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo hw_info) { + auto grid = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cluster_shape_mnk); + return possibly_transpose_grid(params.raster_order_, params.divmod_cluster_shape_m_, params.divmod_cluster_shape_n_, grid); + } + + // Possibly transpose the grid depending on rasterization order. + CUTLASS_HOST_DEVICE + static dim3 + possibly_transpose_grid(RasterOrder raster_order, FastDivmod divmod_cluster_shape_m, FastDivmod divmod_cluster_shape_n, dim3 grid) { + if (raster_order == RasterOrder::AlongN) { + // Swap grid.x and grid.y for AlongN rasterization order, since the CLC scheduler + // will schedule in AlongM order by default. + // + // Each grid dimension must also be a multiple of the corresponding cluster dimension, + // so we convert the untransposed x into the number of clusters along the M mode, + // and multiply this by cluster.n (and vice-versa for y). + auto tmp = grid.x; + grid.x = divmod_cluster_shape_n.divide(grid.y) * divmod_cluster_shape_m; + grid.y = divmod_cluster_shape_m.divide(tmp) * divmod_cluster_shape_n; + } + return grid; + } + + template + static size_t + get_workspace_size( + Arguments const& args, + ProblemShape problem_shape, + KernelHardwareInfo const& hw_info, + [[maybe_unused]] uint32_t reduction_warp_groups, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + [[maybe_unused]] uint32_t num_accumulator_mtxs = 1) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + auto cs = cutlass::detail::select_cluster_shape(ClusterShape_{}, hw_info.cluster_shape); + + return Params::get_workspace_size( + to_gemm_coord(problem_shape_mnkl), + GemmCoord(1, 1, 1), // Tile shape. Unused. 
+ to_gemm_coord(cs), + hw_info, + args.max_swizzle_size, + args.raster_order + ); + } + + template + static size_t + get_workspace_size(Arguments const& args, ProblemShape problem_shape, TileShapeMNK, AtomThrShape, ClusterShape, KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, uint32_t num_accumulator_mtxs = 1) { + return get_workspace_size(args, problem_shape, hw_info, reduction_warp_groups, num_accumulator_mtxs); + } + + template + static cutlass::Status + initialize_workspace( + Arguments const& args, + void* workspace, + cudaStream_t stream, + ProblemShape const& problem_shape, + KernelHardwareInfo const& hw_info, + uint32_t, // reduction_warp_groups + uint32_t = 1, // epilogue_subtile + uint32_t = 1, // num_accumulator_mtxs + CudaHostAdapter *cuda_adapter = nullptr) { + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + auto cs = cutlass::detail::select_cluster_shape(ClusterShape_{}, hw_info.cluster_shape); + + return Params::initialize_workspace( + workspace, + stream, + to_gemm_coord(problem_shape_mnkl), + GemmCoord(1, 1, 1), // Tile shape. Unused. + to_gemm_coord(cs), + hw_info, + args.max_swizzle_size, + args.raster_order, + cuda_adapter + ); + } + + template + static cutlass::Status + initialize_workspace( + Arguments const& args, + void* workspace, + cudaStream_t stream, + ProblemShape const& problem_shape, + TileShapeMNK, + AtomThrShape, + ClusterShape, + KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, + uint32_t num_accumulator_mtxs = 1, + CudaHostAdapter *cuda_adapter = nullptr) { + + return initialize_workspace( + args, + workspace, + stream, + problem_shape, + hw_info, + reduction_warp_groups, + 1, // epilogue_subtile + num_accumulator_mtxs, + cuda_adapter + ); + } + + static bool + can_implement(Arguments const& args) { + return true; + } + + // + // Constructors + // + CUTLASS_DEVICE + PersistentTileSchedulerSm100(Params const& params) + : scheduler_params(params) {} + + CUTLASS_DEVICE + PersistentTileSchedulerSm100(CLCResponse* clc_response_ptr, Params const& params, dim3 block_id_in_cluster) + : clc_response_ptr_(clc_response_ptr), scheduler_params(params), block_id_in_cluster_(block_id_in_cluster) {} + + template + CUTLASS_DEVICE + PersistentTileSchedulerSm100(CLCResponse* clc_response_ptr, Params const& params, ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, dim3 block_id_in_cluster) + : PersistentTileSchedulerSm100(clc_response_ptr, params, block_id_in_cluster) {} + // + // Data Members + // + CLCResponse *clc_response_ptr_ = nullptr; + Params const& scheduler_params; + dim3 block_id_in_cluster_; + + // + // Work Tile API + // + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE + static WorkTileInfo + initial_work_tile_info(ClusterShape cluster_shape, Params const& params) { + WorkTileInfo work_tile{ + static_cast((blockIdx.x / cute::size<0>(cluster_shape)) * cute::size<0>(cluster_shape)), + static_cast((blockIdx.y / cute::size<1>(cluster_shape)) * cute::size<1>(cluster_shape)), + static_cast((blockIdx.z / cute::size<2>(cluster_shape)) * cute::size<2>(cluster_shape)), + true + }; + + possibly_transpose_work_tile(work_tile, params); + return work_tile; + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape cluster_shape) { + return initial_work_tile_info(cluster_shape, scheduler_params); + } + + CUTLASS_DEVICE + auto + work_tile_to_cta_coord(WorkTileInfo 
work_tile_info) { + // Get every cta coord in three dimensions of the cluster + auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = block_id_in_cluster_; + return make_coord( + work_tile_info.M_idx + static_cast(cta_m_in_cluster), + work_tile_info.N_idx + static_cast(cta_n_in_cluster), + _, + work_tile_info.L_idx + static_cast(cta_l_in_cluster) + ); + } + + // Convert CTA-level work tile info to cluster-level tile coord + CUTLASS_DEVICE + auto + work_tile_to_cluster_coord_mnkl(WorkTileInfo work_tile_info) const { + // TileScheduler works at CTA-level, kernel works at cluster-level + int m_coord = idx2crd(scheduler_params.divmod_cluster_shape_m_.divide(work_tile_info.M_idx), + scheduler_params.problem_tiles_m_); + int n_coord = idx2crd(scheduler_params.divmod_cluster_shape_n_.divide(work_tile_info.N_idx), + scheduler_params.problem_tiles_n_); + int l_coord = idx2crd(work_tile_info.L_idx, + scheduler_params.problem_tiles_l_); + return make_coord(m_coord, n_coord, _, l_coord); + } + + CUTLASS_DEVICE + static void + issue_clc_query(PipelineState state, uint32_t mbarrier_addr, CLCResponse* clc_response_ptr) { + #if defined(CUTLASS_ARCH_CLC_ENABLED) + uint32_t result_addr = cute::cast_smem_ptr_to_uint(reinterpret_cast( + &clc_response_ptr[state.index()])); + asm volatile( + "{\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n\t" + "}\n" + : + : "r"(result_addr), "r"(mbarrier_addr)); + #else + CUTLASS_NOT_IMPLEMENTED(); + #endif + } + + CUTLASS_DEVICE + static WorkTileInfo + work_tile_info_from_clc_response(uint32_t result_addr) { + WorkTileInfo work_tile_info; + uint32_t valid = 0; + + #if defined(CUTLASS_ARCH_CLC_ENABLED) + asm volatile( + "{\n" + ".reg .pred p1;\n\t" + ".reg .b128 clc_result;\n\t" + "ld.shared.b128 clc_result, [%4];\n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" + "selp.u32 %3, 1, 0, p1;\n\t" + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {%0, %1, %2, _}, clc_result;\n\t" + "}\n" + : "=r"(work_tile_info.M_idx), "=r"(work_tile_info.N_idx), "=r"(work_tile_info.L_idx), "=r"(valid) + : "r"(result_addr) + : "memory" + ); + + cutlass::arch::fence_view_async_shared(); + #else + CUTLASS_NOT_IMPLEMENTED(); + #endif + work_tile_info.is_valid_tile = (valid == 1); + return work_tile_info; + } + + CUTLASS_DEVICE + PipelineState + advance_to_next_work(Pipeline& clc_pipeline, PipelineState clc_pipe_producer_state) const { + uint32_t mbarrier_addr = clc_pipeline.producer_get_barrier(clc_pipe_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + clc_pipeline.producer_acquire(clc_pipe_producer_state); + + if (cute::elect_one_sync()) { + issue_clc_query(clc_pipe_producer_state, mbarrier_addr, clc_response_ptr_); + } + + ++clc_pipe_producer_state; + return clc_pipe_producer_state; + } + + // Kernel helper function to get next work tile + template + CUTLASS_DEVICE + auto + fetch_next_work( + WorkTileInfo work_tile_info, + TileSchedulerPipeline& scheduler_pipeline, + TileSchedulerPipelineState scheduler_pipe_consumer_state) { + + scheduler_pipeline.consumer_wait(scheduler_pipe_consumer_state); + auto new_work_tile_info = get_current_work(scheduler_pipe_consumer_state); + scheduler_pipeline.consumer_release(scheduler_pipe_consumer_state); + + // Return true to indicate that the tile scheduler pipeline state should be advanced + return cute::make_tuple(new_work_tile_info, true); + } + + // + // K Tile API + // + // Permute K 
iteration loading order from [C, S, R, T] to [S, R, T, C] for better L2 locality + template + CUTLASS_DEVICE + auto + get_k_tile_iterator(WorkTileInfo const& work_tile_info, ProblemShapeMNKL problem_shape_MNKL, TileShape tile_shape, Shape) { + constexpr int32_t rank_t = cute::rank<2>(ProblemShapeMNKL{}); + auto k_tiles = cute::ceil_div(cute::get<2>(problem_shape_MNKL), cute::get<2>(tile_shape)); + if constexpr (rank_t == 4) { + return cute::make_coord_iterator>(k_tiles); + } + else if constexpr (rank_t == 3) { + return cute::make_coord_iterator>(k_tiles); + } + else if constexpr (rank_t == 2) { + return cute::make_coord_iterator>(k_tiles); + } + else { + return cute::make_coord_iterator(k_tiles); + } + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. + return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); + } + + // Compatible with sm90 kernel layers + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const&) { + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the basic tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&, Params const&) { + return true; + } + + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&) { + return true; + } + + // Returns whether fixup is needed for `work_tile_info`. None of the work units returned by + // this scheduler require fixup, since none of the work units partition the reduction extent. + CUTLASS_HOST_DEVICE + static bool + requires_fixup(Params const& params, WorkTileInfo const work_tile_info) { + return false; + } + + // Performs the reduction across splits for a given output tile. No fixup is required for + // work units returned by this scheduler. + template + CUTLASS_DEVICE + void + fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { } + + template < + bool IsComplex, + class TiledMma, + class AccEngine, + class AccLayout, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class CopyOpT2R + > + CUTLASS_DEVICE + AccumulatorPipelineState + fixup( + TiledMma const& , + WorkTileInfo const&, + cute::Tensor&, + AccumulatorPipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + CopyOpT2R) const { + return acc_pipe_consumer_state; + } + + // Returns whether the current WorkTileInfo passed in should continue to be used. Since + // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo + // passed in should not be used after having been processed. + CUTLASS_DEVICE + static bool + continue_current_work(WorkTileInfo&) { + return false; + } + + // + // Implementation Helpers + // + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. 
+ template + CUTLASS_HOST_DEVICE static dim3 + get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) { + auto grid_shape = shape(ceil_div(problem_shape_mnkl, blk_shape)); + auto grid_shape_up = round_up(product_each(grid_shape), cluster_shape); // Assumes ClusterShape is flat + return dim3(size<0>(grid_shape_up), // M + size<1>(grid_shape_up), // N + size<3>(grid_shape_up)); // L + } + + template + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, + TileShape tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk) { + auto [tiles_m, tiles_n, tiles_l] = product_each(ceil_div(select<0,1,3>(problem_shape_mnkl), take<0,2>(tile_shape_mnk))); + auto ctas_m = round_nearest(tiles_m * size<0>(atom_thr_shape_mnk), size<0>(cluster_shape_mnk)); + auto ctas_n = round_nearest(tiles_n * size<1>(atom_thr_shape_mnk), size<1>(cluster_shape_mnk)); + auto ctas_l = tiles_l; + + return {static_cast(ctas_m), + static_cast(ctas_n), + static_cast(ctas_l)}; + } + + + // Get clcID and success bit + [[nodiscard]] CUTLASS_DEVICE + WorkTileInfo + get_current_work(PipelineState state) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&clc_response_ptr_[state.index()]); + auto work_tile = work_tile_info_from_clc_response(smem_addr); + possibly_transpose_work_tile(work_tile); + return work_tile; + } + + // Set data SMEM ptr + CUTLASS_DEVICE + void + set_data_ptr(CLCResponse* clc_response_ptr) { + clc_response_ptr_ = clc_response_ptr; + } + + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return false; + } + + template + CUTLASS_DEVICE + static void + fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} + + + CUTLASS_DEVICE + auto + fetch_next_work(WorkTileInfo work_tile_info) { + return cute::make_tuple(work_tile_info, true); + } + + CUTLASS_DEVICE + static cute::tuple + possibly_transpose_work_tile(Params::RasterOrder raster_order, int32_t M_idx, int32_t N_idx, FastDivmod divmod_cluster_shape_m, FastDivmod divmod_cluster_shape_n) { + if (raster_order == Params::RasterOrder::AlongN) { + int cluster_m, remainder_m, cluster_n, remainder_n; + divmod_cluster_shape_m(cluster_m, remainder_m, M_idx); + divmod_cluster_shape_n(cluster_n, remainder_n, N_idx); + M_idx = cluster_n * divmod_cluster_shape_m.divisor + remainder_m; + N_idx = cluster_m * divmod_cluster_shape_n.divisor + remainder_n; + } + return cute::make_tuple(M_idx, N_idx); + } + + + CUTLASS_DEVICE + static void + possibly_transpose_work_tile(WorkTileInfo& work_tile_info, Params const& params) { + auto [M_idx, N_idx] = possibly_transpose_work_tile( + params.raster_order_, work_tile_info.M_idx, work_tile_info.N_idx, params.divmod_cluster_shape_m_, params.divmod_cluster_shape_n_); + work_tile_info.M_idx = M_idx; + work_tile_info.N_idx = N_idx; + } + + CUTLASS_DEVICE + void + possibly_transpose_work_tile(WorkTileInfo& work_tile_info) { + possibly_transpose_work_tile(work_tile_info, scheduler_params); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp new file mode 100755 index 0000000000..2e64eac3af --- /dev/null +++ 
b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -0,0 +1,309 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/arch/barrier.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel::detail { + +//////////////////// Blackwell Grouped Static Scheduler ///////////////////////// + +// This tile scheduler is a SM100 wrapper for scheduling by the SM90 Group tile scheduler. +// This helps to enable reusing SM90 group tile scheduling capability for SM100 kernels +// (e.g., support for CTA rasterization). + +// For Grouped GEMM, the most common use case has Problem Shapes for all groups only on device. +// Therefore, we don't know how many tiles there will be for the scheduler to hand out. +// Hence, we have a SM90 style static group scheduler that launches the largest grid possible. +// If we had access to host-side problem shapes, one could use them to figure out the grid shape +// and thereafter use CLC query (which can then be linearized and mapped to an appropriate tile coord).
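// ------------------------------------------------------------------------------------------------
// Editorial sketch (not part of this patch): the comment above observes that, with host-side
// problem shapes, one could size the grid from the exact tile count rather than launching the
// largest grid possible. That count is just a sum of ceil-divisions over groups, as sketched
// below. GroupShapeMN, total_cta_tiles, tile_m and tile_n are hypothetical names used only for
// illustration; they are not CUTLASS APIs.
#include <cstdint>
#include <vector>

struct GroupShapeMN { int M; int N; };

// Total number of output CTA tiles across all groups when each group is tiled independently
// by a tile_m x tile_n CTA tile.
inline std::int64_t total_cta_tiles(std::vector<GroupShapeMN> const& groups, int tile_m, int tile_n) {
  std::int64_t tiles = 0;
  for (GroupShapeMN const& g : groups) {
    tiles += std::int64_t((g.M + tile_m - 1) / tile_m) * std::int64_t((g.N + tile_n - 1) / tile_n);
  }
  return tiles;
}
// ------------------------------------------------------------------------------------------------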
+ +template +class PersistentTileSchedulerSm100Group { + +public: + using UnderlyingScheduler = PersistentTileSchedulerSm90Group; + using UnderlyingProblemShape = typename GroupProblemShape::UnderlyingProblemShape; + using Params = PersistentTileSchedulerSm100GroupParams; + using WorkTileInfo = typename UnderlyingScheduler::WorkTileInfo; + using Arguments = typename UnderlyingScheduler::Arguments; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; + struct CLCResponse { uint32_t data[4]; }; + + static constexpr bool IsDynamicPersistent = UnderlyingScheduler::IsDynamicPersistent; + +private: + UnderlyingScheduler scheduler_sm90; + +public: + template + static Params + to_underlying_arguments( + GroupProblemShape problem_shapes, + TileShape tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace = nullptr) { + + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic + static_assert(cute::is_static::value); + + auto selected_cluster_shape = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); + auto cta_shape = cute::conditional_return>( + shape_div(tile_shape_mnk, atom_thr_shape_mnk), // Dynamic Cluster: For 2SM kernels, use CTA tile shape for the underlying scheduler + shape_div(tile_shape_mnk, selected_cluster_shape)); // Static Cluster: Blackwell builders expects TileShape to be Cluster's Tile Shape, Hopper doesn't + + dim3 problem_blocks = get_tiled_cta_shape_mnl( + problem_shapes.groups(), + problem_shapes, + hw_info, + cta_shape, selected_cluster_shape); + + Params params; + params.initialize( + problem_blocks, + problem_shapes.groups(), + problem_shapes.problem_shapes, + problem_shapes.host_problem_shapes, + to_gemm_coord(cta_shape), + to_gemm_coord(selected_cluster_shape), + hw_info, + args.max_swizzle_size, + args.raster_order + ); + + return params; + } + + static bool + can_implement(Arguments const& args) { + return true; + } + + CUTLASS_DEVICE + PersistentTileSchedulerSm100Group() { } + + CUTLASS_DEVICE + PersistentTileSchedulerSm100Group(CLCResponse* /* clc_response_ptr */, Params const& params) + : scheduler_params(params), + scheduler_sm90(params.params_sm90_) { } + + CUTLASS_DEVICE + PersistentTileSchedulerSm100Group(CLCResponse* /* clc_response_ptr */, Params const& params, dim3 /* block_id_in_cluster */) + : scheduler_params(params), + scheduler_sm90(params.params_sm90_) { } + + template + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape cluster_shape) { + return scheduler_sm90.initial_work_tile_info(cluster_shape); + } + + template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { + return UnderlyingScheduler::get_tiled_cta_shape_mnl(groups, problem_shapes, hw_info, cta_shape, cluster_shape); + } + + // Given the inputs, computes the physical grid we should launch. 
+ template + CUTLASS_HOST_DEVICE + static dim3 + get_grid_shape( + Params const& params, + GroupProblemShape problem_shapes, + BlockShape cta_shape, + [[maybe_unused]] AtomThrShape atom_thr_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info) { + dim3 problem_blocks = get_tiled_cta_shape_mnl( + problem_shapes.groups(), + problem_shapes, + hw_info, + cta_shape, + cluster_shape); + + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + Arguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.params_sm90_.log_swizzle_size_; + } + args.raster_order = params.params_sm90_.raster_order_ == RasterOrder::AlongN ? RasterOrderOptions::AlongN : RasterOrderOptions::AlongM; + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + args.max_swizzle_size, + args.raster_order, + /* truncate_by_problem_size = */true, + cute::is_static_v ? true : false + ); + } + + CUTLASS_DEVICE + static auto + work_tile_to_cta_coord(WorkTileInfo work_tile_info) { + // SM90 static scheduler implicitly handles CTA coord in a Cluster + return make_coord( + work_tile_info.M_idx, + work_tile_info.N_idx, + _, + work_tile_info.L_idx + ); + } + + // + // K Tile API + // + template + CUTLASS_DEVICE + auto + get_k_tile_iterator(WorkTileInfo const& work_tile_info, ProblemShape problem_shape_MNKL, TileShape tile_shape, Shape) { + auto k_tiles = cute::ceil_div(cute::get<2>(problem_shape_MNKL), cute::get<2>(tile_shape)); + return cute::make_coord_iterator(k_tiles); + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the Group tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&, Params const&) { + return true; + } + + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&) { + return true; + } + + // Returns whether fixup is needed for `work_tile_info`. None of the work units returned by + // this scheduler require fixup, since none of the work units partition the reduction extent. + CUTLASS_HOST_DEVICE + static bool + requires_fixup(Params const& params, WorkTileInfo const work_tile_info) { + return false; + } + + // Performs the reduction across splits for a given output tile. No fixup is required for + // work units returned by this scheduler. + template + CUTLASS_DEVICE + void + fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { } + + template + static size_t + get_workspace_size(Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, uint32_t, uint32_t = 1, uint32_t = 1) { + return 0; + } + + template + static size_t + get_workspace_size(Arguments const& args, ProblemShape problem_shape, TileShapeMNK, AtomThrShape, ClusterShape, KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, uint32_t num_accumulator_mtxs = 1) { + return 0; + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape_MNKL, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. 
+ return cute::size(cute::ceil_div(cute::get<2>(problem_shape_MNKL), cute::get<2>(tile_shape))); + } + + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const&) { + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + template + static cutlass::Status + initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape const&, KernelHardwareInfo const&, uint32_t, uint32_t = 1, uint32_t = 1, CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static cutlass::Status + initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape const&, TileShapeMNK, AtomThrShape, ClusterShape, KernelHardwareInfo const&, + uint32_t, uint32_t = 1, CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + // Kernel helper function to get next CLC ID + template + CUTLASS_DEVICE + auto + fetch_next_work( + WorkTileInfo work_tile_info, + [[maybe_unused]] CLCPipeline& clc_pipeline, + [[maybe_unused]] CLCPipelineState clc_pipe_consumer_state) { + + return scheduler_sm90.fetch_next_work(work_tile_info); + } + +private: + // + // Methods + // + [[nodiscard]] CUTLASS_DEVICE + static CLCResponse + load_query_response(uint32_t smem_ptr) { + return UnderlyingScheduler::load_query_response(smem_ptr); + } + // + // Storage + // + CLCResponse *clc_response_ptr_ = nullptr; + Params scheduler_params; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp new file mode 100644 index 0000000000..f7be566fc0 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp @@ -0,0 +1,979 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/arch/barrier.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel::detail { + +// Persistent Thread Block (TB) scheduler leveraging stream-K decomposition +template < + class TileShape, + class ClusterShape, + uint32_t Stages_ +> +class PersistentTileSchedulerSm100StreamK { + using UnderlyingScheduler = PersistentTileSchedulerSm100; + using UnderlyingStreamKScheduler = PersistentTileSchedulerSm90StreamK; + using InternalWorkTileInfo = typename UnderlyingScheduler::WorkTileInfo; + using InternalParams = typename UnderlyingScheduler::Params; + // Shapediv failures currently occur with tile shape N of 192 + static constexpr bool ForceDataParallel = size<1>(TileShape{}) == 192; + +public: + static constexpr uint32_t Stages = Stages_; + + using CLCResponse = typename UnderlyingScheduler::CLCResponse; + using WorkTileInfo = typename UnderlyingStreamKScheduler::WorkTileInfo; + using Arguments = typename UnderlyingStreamKScheduler::Arguments; + + using Params = PersistentTileSchedulerSm100StreamKParams; + using RasterOrder = PersistentTileSchedulerSm90Params::RasterOrder; + using RasterOrderOptions = PersistentTileSchedulerSm90Params::RasterOrderOptions; + + using SharedStorage = typename UnderlyingScheduler::SharedStorage; + using Pipeline = typename UnderlyingScheduler::Pipeline; + using ThrottlePipeline = typename UnderlyingScheduler::ThrottlePipeline; + + static constexpr bool IsDynamicPersistent = true; + + // Number of sub blocks in the kernel epilogue + static constexpr int EpilogueSubtiles = 1; + + CUTLASS_HOST_DEVICE + PersistentTileSchedulerSm100StreamK() { } + + CUTLASS_DEVICE + PersistentTileSchedulerSm100StreamK(Params const& params) + : sm100_scheduler_(params.sm100_params_) + , params_(params) + , block_id_in_cluster_(cute::block_id_in_cluster()) { + // Set the current linear idx to be equal to the linear idx of the first work tile to be computed + auto cs = make_shape( + params.sm100_params_.divmod_cluster_shape_m_.divisor, + params.sm100_params_.divmod_cluster_shape_n_.divisor, + Int<1>{}); + } + + CUTLASS_DEVICE + PersistentTileSchedulerSm100StreamK(CLCResponse* clc_response_ptr, Params const& params, dim3 block_id_in_cluster) + : sm100_scheduler_(clc_response_ptr, params.sm100_params_, block_id_in_cluster), + params_(params), + block_id_in_cluster_(block_id_in_cluster) { + // Set the current linear idx to be equal to the linear idx of the first work tile to be computed + auto cs = make_shape( + params.sm100_params_.divmod_cluster_shape_m_.divisor, + params.sm100_params_.divmod_cluster_shape_n_.divisor, + 
Int<1>{}); + } + + template + CUTLASS_DEVICE + PersistentTileSchedulerSm100StreamK(CLCResponse* clc_response_ptr, Params const& params, + ProblemShape problem_shape_mnkl, TileShapeMNK tile_shape, dim3 block_id_in_cluster) + : PersistentTileSchedulerSm100StreamK(clc_response_ptr, params, block_id_in_cluster) { } + + template + static Params + to_underlying_arguments( + ProblemShape problem_shape, + TileShape tile_shape, + [[maybe_unused]] ClusterShape cluster_shape, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + uint32_t ktile_start_alignment_count = 1u) { + + auto cs = cutlass::detail::select_cluster_shape(cluster_shape, hw_info.cluster_shape); + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cs); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + Params params; + params.initialize( + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(cs), + hw_info, + args.splits, + args.max_swizzle_size, + args.raster_order, + args.reduction_mode, + ForceDataParallel ? Params::DecompositionMode::DataParallel : args.decomposition_mode, + workspace, + ktile_start_alignment_count + ); + return params; + } + + template + static Params + to_underlying_arguments( + ProblemShape problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace = nullptr, + uint32_t ktile_start_alignment_count = 1u + ) { + + auto cs = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cs); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + Params params; + params.initialize( + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(cs), + hw_info, + args.splits, + args.max_swizzle_size, + args.raster_order, + args.reduction_mode, + ForceDataParallel ? 
Params::DecompositionMode::DataParallel : args.decomposition_mode, + workspace, + ktile_start_alignment_count + ); + + return params; + } + + static bool + can_implement(Arguments const& args) { + return UnderlyingStreamKScheduler::can_implement(args); + } + + CUTLASS_DEVICE + PipelineState + advance_to_next_work(Pipeline& clc_pipeline, PipelineState clc_pipe_producer_state) const { + return sm100_scheduler_.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Get clcID and success bit + [[nodiscard]] CUTLASS_DEVICE + WorkTileInfo + get_current_work(PipelineState state) { + InternalWorkTileInfo work_tile_info = sm100_scheduler_.get_current_work(state); + if (!work_tile_info.is_valid()) { + return invalid_work_tile(); + } + + return convert_work(work_tile_info); + } + // Given the inputs, computes the total number of output blocks this problem will compute over + template + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl(ProblemShape problem_shape_mnkl, TileShape blk_shape, ClusterShape cluster_shape) { + return UnderlyingScheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, blk_shape, cluster_shape); + } + + template + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl(ProblemShape problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk) { + return UnderlyingScheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cluster_shape_mnk); + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE + static dim3 + get_grid_shape( + Params const& params, + ProblemShape problem_shape, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + [[maybe_unused]] Arguments arguments) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); + return params.get_grid_shape(problem_blocks, to_gemm_coord(cluster_shape)); + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE + static dim3 + get_grid_shape( + Params const& params, + ProblemShape problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo hw_info) { + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cluster_shape_mnk); + return params.get_grid_shape(problem_blocks, to_gemm_coord(cluster_shape_mnk)); + } + + + // Returns the initial work tile info that will be computed over + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape cluster_shape) { + InternalWorkTileInfo work_tile_info = UnderlyingScheduler::initial_work_tile_info(cluster_shape, params_.sm100_params_); + work_tile_info.is_valid_tile = false; + return convert_work(work_tile_info); + } + + // Returns a CTA-tiled coordinate for the provided work tile info + CUTLASS_DEVICE + auto + work_tile_to_cta_coord(WorkTileInfo const& work_tile_info) { + if (is_dp_only()) { + // For data-parallel decompositions, simply default to the + // underlying SM100 scheduler. 
+ auto underlying_work_tile = to_underlying_work_tile_info(work_tile_info); + return sm100_scheduler_.work_tile_to_cta_coord(underlying_work_tile); + } + else { + // The SM90 stream-K scheduler already operates only at CTA level, + // so the returned work tile info already contains CTA offsets within + // each cluster tile. + return cute::make_coord( + work_tile_info.M_idx, + work_tile_info.N_idx, + _, + work_tile_info.L_idx + ); + } + } + + // Returns whether the current work_tile_info passed in should continue to be used. + CUTLASS_DEVICE + bool + continue_current_work(WorkTileInfo& work_tile_info) const { + return UnderlyingStreamKScheduler::continue_current_work_for_linear_idx( + current_work_linear_idx_, unit_iter_start_, block_id_in_cluster_, work_tile_info, params_.sk_params_); + } + + // Kernel helper function to get next CLC ID and whether to advance the CLC pipeline state. + template + CUTLASS_DEVICE + cute::tuple + fetch_next_work( + WorkTileInfo work_tile_info, + CLCPipeline& clc_pipeline, + CLCPipelineState clc_pipe_consumer_state) { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. Return `false` to indicate that the CLC pipeline state + // need not be advanced. + if (continue_current_work(work_tile_info)) { + return cute::make_tuple(work_tile_info, false); + } + + clc_pipeline.consumer_wait(clc_pipe_consumer_state); + auto new_work_tile_info = get_current_work(clc_pipe_consumer_state); + clc_pipeline.consumer_release(clc_pipe_consumer_state); + + // Return true to indicate that the CLC pipeline state should be advanced + return cute::make_tuple(new_work_tile_info, true); + } + + CUTLASS_DEVICE + cute::tuple + fetch_next_work(WorkTileInfo work_tile_info) { + return cute::make_tuple(work_tile_info, true); + } + + // Set data SMEM ptr + CUTLASS_DEVICE + void + set_data_ptr(CLCResponse* clc_response_ptr) { + sm100_scheduler_.set_data_ptr(clc_response_ptr); + } + + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return false; + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const& work_tile_info, Params const& params) { + return UnderlyingStreamKScheduler::compute_epilogue(work_tile_info, params.sk_params_); + } + + // Non-static variant of compute_epilogue. Used in cases where passing + // in Params is inconvenient. 
+ CUTLASS_HOST_DEVICE + bool + compute_epilogue(WorkTileInfo const& work_tile_info) const { + return UnderlyingStreamKScheduler::compute_epilogue(work_tile_info, params_.sk_params_); + } + + template + static size_t + get_workspace_size( + Arguments const& args, + ProblemShape problem_shape, + KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + auto cs = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + TileShape tile_shape; + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cs); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + return Params::get_workspace_size( + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(tile_shape), + to_gemm_coord(cs), + hw_info, + args.splits, + args.max_swizzle_size, + args.raster_order, + ForceDataParallel ? Params::DecompositionMode::DataParallel : args.decomposition_mode, + args.reduction_mode, + reduction_warp_groups, + sizeof_bits::value, + sizeof_bits::value, + EpilogueSubtiles, + num_accumulator_mtxs, + ktile_start_alignment_count + ); + } + + template + static size_t + get_workspace_size( + Arguments const& args, + ProblemShape problem_shape, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + auto cs = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cs); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + auto cta_tile_shape_mnk = shape_div(tile_shape_mnk, atom_thr_shape_mnk); + + return Params::get_workspace_size( + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(cta_tile_shape_mnk), + to_gemm_coord(cs), + hw_info, + args.splits, + args.max_swizzle_size, + args.raster_order, + ForceDataParallel ? 
Params::DecompositionMode::DataParallel : args.decomposition_mode, + args.reduction_mode, + reduction_warp_groups, + sizeof_bits::value, + sizeof_bits::value, + EpilogueSubtiles, + num_accumulator_mtxs, + ktile_start_alignment_count + ); + } + + template + static cutlass::Status + initialize_workspace( + Arguments const& args, + void* workspace, + cudaStream_t stream, + ProblemShape const& problem_shape, + KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + uint32_t num_accumulator_mtxs = 1, + CudaHostAdapter *cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + auto cs = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + TileShape tile_shape; + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cs); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + return Params::initialize_workspace( + workspace, + stream, + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(tile_shape), + to_gemm_coord(cs), + hw_info, + args.splits, + args.max_swizzle_size, + args.raster_order, + ForceDataParallel ? Params::DecompositionMode::DataParallel : args.decomposition_mode, + args.reduction_mode, + reduction_warp_groups, + sizeof_bits::value, + sizeof_bits::value, + EpilogueSubtiles, + num_accumulator_mtxs, + cuda_adapter, + ktile_start_alignment_count + ); + } + + template + static cutlass::Status + initialize_workspace( + Arguments const& args, + void* workspace, + cudaStream_t stream, + ProblemShape const& problem_shape, + TileShapeMNK tile_shape_mnk, + AtomThrShape atom_thr_shape_mnk, + ClusterShape cluster_shape_mnk, + KernelHardwareInfo const& hw_info, + uint32_t reduction_warp_groups, + uint32_t num_accumulator_mtxs = 1, + CudaHostAdapter *cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + auto cs = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); + + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cs); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + auto cta_tile_shape_mnk = shape_div(tile_shape_mnk, atom_thr_shape_mnk); + + return Params::initialize_workspace( + workspace, + stream, + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(cta_tile_shape_mnk), + to_gemm_coord(cs), + hw_info, + args.splits, + args.max_swizzle_size, + args.raster_order, + ForceDataParallel ? 
Params::DecompositionMode::DataParallel : args.decomposition_mode, + args.reduction_mode, + reduction_warp_groups, + sizeof_bits::value, + sizeof_bits::value, + EpilogueSubtiles, + num_accumulator_mtxs, + cuda_adapter, + ktile_start_alignment_count + ); + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape, TileShapeMNK) { + return work_tile_info.k_tile_count; + } + + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const& work_tile_info) { + return work_tile_info.K_idx; + } + + template + CUTLASS_DEVICE + auto + get_k_tile_iterator(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShapeMNK tile_shape, Shape) { + // Get the shape of k tiles instead of the counter. Otherwise, if the problem shape has + // multiple k modes, the DMA loop would need to decompose the iterator onto every mode + // every time global loading happens. This would incur extra overhead. + auto k_tiles = cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape)); + auto k_tile_start = get_work_k_tile_start(work_tile_info); + // Iterate start from current k tile start over the k tiles shape. + return cute::make_coord_iterator(idx2crd(k_tile_start, k_tiles), k_tiles); + } + + // Returns whether fixup is needed for `work_tile_info`. + CUTLASS_HOST_DEVICE + bool + requires_fixup(WorkTileInfo const work_tile_info) const { + return UnderlyingStreamKScheduler::requires_fixup(params_.sk_params_, work_tile_info); + } + + // Performs the reduction across splits for a given output tile. + template + CUTLASS_DEVICE + void + fixup( + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx, + uint32_t num_accumulator_mtxs = 1) const { + + using BarrierManager = SyncManager; + + UnderlyingStreamKScheduler s; + return s.template fixup_helper( + params_.sk_params_, work_tile_info, accumulators, num_barriers, barrier_idx, num_accumulator_mtxs); + } + + + // Performs the reduction across splits for a given output tile. 
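// ------------------------------------------------------------------------------------------------
// Editorial sketch (not part of this patch): a scalar analogue of the reasoning in
// get_k_tile_iterator above. Decomposing the linear k-tile start once, and then advancing a
// multi-mode coordinate with simple carries, avoids a div/mod on every load iteration. The
// two-mode shape, the mode ordering (first mode varies fastest, as in a compact column-major
// layout), and the names below are illustrative assumptions, not CUTLASS definitions.
#include <array>

struct KTileCoord2 { int k0; int k1; };  // two k modes; k0 varies fastest

// Decompose the starting linear k-tile index into a coordinate once, up front.
inline KTileCoord2 decompose_once(int k_tile_start, std::array<int, 2> k_tiles_shape) {
  return { k_tile_start % k_tiles_shape[0], k_tile_start / k_tiles_shape[0] };
}

// Advance by one k tile without re-deriving the coordinate from a linear counter.
inline void advance(KTileCoord2& c, std::array<int, 2> k_tiles_shape) {
  if (++c.k0 == k_tiles_shape[0]) { c.k0 = 0; ++c.k1; }
}
// ------------------------------------------------------------------------------------------------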
+ template + CUTLASS_DEVICE + static void + fixup( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + UnderlyingStreamKScheduler::fixup(params.sk_params_, work_tile_info, accumulators, num_barriers, barrier_idx); + } + + // Performs reduction across splits for a given output tile + template < + bool IsComplex, + class TiledMma, + class AccEngine, + class AccLayout, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class CopyOpT2R + > + CUTLASS_DEVICE + AccumulatorPipelineState + fixup( + TiledMma const& tiled_mma, + WorkTileInfo const& work_tile_info, + cute::Tensor& accumulators, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + CopyOpT2R) const { + using namespace cute; + static_assert(cute::is_rmem_v || cute::is_tmem_v, "Accumulator must be in either TMEM or RF"); + + if constexpr (ForceDataParallel) { + return acc_pipe_consumer_state; + } + else { + if (!requires_fixup(work_tile_info)) { + if constexpr (cute::is_tmem_v) { + if (!work_tile_info.is_valid()) { + // The first work tile can be invalid, but still must release TMEM + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + } + return acc_pipe_consumer_state; + } + + if constexpr (cute::is_tmem_v) { + // When accumulators reside in TMEM, perform TMEM -> RF loads before performing fixup, + // and perform RF -> TMEM stores after fixup (when the split must compute the epilogue) + if constexpr (IsComplex) { + constexpr uint32_t NumAccumulatorMtx = 2; + Tensor accumulators_real = accumulators(_,_,_,0); + tmem_fixup( + tiled_mma, + work_tile_info, + accumulators_real, + acc_pipeline, + acc_pipe_consumer_state, + CopyOpT2R{}, + NumAccumulatorMtx, + 0 /*idx_accumulator_mtx*/ + ); + + Tensor accumulators_imag = accumulators(_,_,_,1); + return tmem_fixup( + tiled_mma, + work_tile_info, + accumulators_imag, + acc_pipeline, + acc_pipe_consumer_state, + CopyOpT2R{}, + NumAccumulatorMtx, + 1 /*idx_accumulator_mtx*/ + ); + } + else { + return tmem_fixup( + tiled_mma, + work_tile_info, + accumulators, + acc_pipeline, + acc_pipe_consumer_state, + CopyOpT2R{} + ); + } + } + else { + // Simply perform fixup without TMEM loads when accumulators reside in RF + constexpr uint32_t ThreadsForFixup = NumThreadsPerWarpGroup; + constexpr uint32_t Offset = static_cast(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); + constexpr uint32_t MaxNumNamedBarriers = 1; + constexpr uint32_t BarrierIdx = 0; + using BarrierManager = NamedBarrierManager; + constexpr int NumAccumulatorMtx = IsComplex ? 2 : 1; + + UnderlyingStreamKScheduler::template fixup_helper, BarrierManager>( + params_.sk_params_, work_tile_info, accumulators, MaxNumNamedBarriers, BarrierIdx, NumAccumulatorMtx); + return acc_pipe_consumer_state; + } + } + } + + + + // Convert CTA-level work tile info to cluster-level tile coord + CUTLASS_DEVICE + auto + work_tile_to_cluster_coord_mnkl(WorkTileInfo work_tile_info) const { + typename UnderlyingScheduler::WorkTileInfo tmp{ + work_tile_info.M_idx, + work_tile_info.N_idx, + work_tile_info.L_idx, + work_tile_info.is_valid() + }; + return sm100_scheduler_.work_tile_to_cluster_coord_mnkl(tmp); + } + +private: + CUTLASS_HOST_DEVICE + WorkTileInfo invalid_work_tile() const { + // Mark the work tile as invalid based on its having a 0 K tiles to comptue. 
+ // Set the M, N, and L indices to be outside of the range of valid tiles for the problem. + return { + static_cast(params_.sm100_params_.problem_tiles_m_) * params_.sm100_params_.divmod_cluster_shape_m_.divisor, + static_cast(params_.sm100_params_.problem_tiles_n_) * params_.sm100_params_.divmod_cluster_shape_n_.divisor, + 0, // K_idx + static_cast(params_.sm100_params_.problem_tiles_l_), + 0 // k_tile_count + }; + } + + // Converts the work tile info returned by the SM100 scheduler to a linear index + CUTLASS_DEVICE + uint64_t + to_linear_idx( + InternalWorkTileInfo const& work_tile_info, + Params const& params) { + // The InternalWorkTileInfo returned from CLC query gives all CTAs in a cluster + // the tile offset corresponding to the first CTA tile in the cluster tile assigned + // to the cluster. Since the SM90 tile scheduler operates at CTA level, we must assign + // each CTA its own tile when computing the linear ID to be used by the SM90 + // stream-K scheduler. + auto start_cta_m_preferred_cluster = params.sk_params_.truncate_to_cluster_size_m(work_tile_info.M_idx); + auto start_cta_n_preferred_cluster = params.sk_params_.truncate_to_cluster_size_n(work_tile_info.N_idx); + uint64_t cluster_idx = gridDim.y * start_cta_m_preferred_cluster + start_cta_n_preferred_cluster; + uint64_t sm_count = gridDim.x * gridDim.y; + uint64_t wave_idx = work_tile_info.L_idx; + + auto cluster_start_linear_id = sm_count * wave_idx + cluster_idx; + + // Determine the offset of this CTA in the preferred cluster shape. + // This calculation aims to accomodate both cases in which this CTA is part of a preferred cluster + // and those in which it is part of a fallback cluster. + // + // The calculation is performed by computing the starting M and N index of the preferred cluster that + // this CTA would be in, and then subtracting these from the true CTA M and N indexes. + // + // In the case where this CTA is part of a preferred cluster, the resulting offsets are equivalent + // to those returned by cute::block_id_in_cluster(); + auto [cta_m_in_cluster, cta_n_in_cluster, _] = block_id_in_cluster_; + uint64_t cta_m_in_preferred_cluster = work_tile_info.M_idx + cta_m_in_cluster - start_cta_m_preferred_cluster; + uint64_t cta_n_in_preferred_cluster = work_tile_info.N_idx + cta_n_in_cluster - start_cta_n_preferred_cluster; + + if (params.sk_params_.raster_order_ == RasterOrder::AlongN) { + return cluster_start_linear_id + (params.sk_params_.divmod_cluster_shape_minor_.divisor * cta_n_in_preferred_cluster) + cta_m_in_preferred_cluster; + } + else { + return cluster_start_linear_id + (params.sk_params_.divmod_cluster_shape_minor_.divisor * cta_m_in_preferred_cluster) + cta_n_in_preferred_cluster; + } + } + + // Converts the work tile info returned by the SM100 scheduler to a stream-K work tile info + CUTLASS_DEVICE + WorkTileInfo + convert_work(InternalWorkTileInfo const& work_tile_info) { + if (has_sk_work()) { + current_work_linear_idx_ = to_linear_idx(work_tile_info, params_); + auto work = UnderlyingStreamKScheduler::get_current_work_for_linear_idx(unit_iter_start_, current_work_linear_idx_, block_id_in_cluster_, params_.sk_params_); + if (!work.is_valid()) { + return invalid_work_tile(); + } + return work; + } + else if (is_split_k()) { + // Split-K offsets are returned directly by CLC query (rather than being + // returned by the SM90 stream-K tile scheduler). 
CLC query returns + // the first CTA tile of work for each CTA in a cluster, but later use of the + // split-K work tile for fixup expects a CTA-offset tile. Thus, we need to offset + // each CTA's M and N index by the CTA offset in the cluster. + auto [cta_m_in_cluster, cta_n_in_cluster, _] = block_id_in_cluster_; + auto M_idx = work_tile_info.M_idx + cta_m_in_cluster; + auto N_idx = work_tile_info.N_idx + cta_n_in_cluster; + + int L_idx, Split_idx; + params_.sk_params_.divmod_splits_(L_idx, Split_idx, work_tile_info.L_idx); + + // TODO: Modularize the SM90 scheduler to pull out and reuse this redundant code + int additional_k_tiles = 0; + int split_start_offset = params_.sk_params_.big_units_; + + if (Split_idx < params_.sk_params_.big_units_) { + // Offsets for "big" units. One additional k iteration is performed, + // and each split preceding us was a big unit, so we must increase + // our split starting offset by our split ID (Split_idx). + additional_k_tiles = 1; + split_start_offset = Split_idx; + } + + // Set up k iteration count and split starting iteration assuming the + // iteration space is evenly split. + uint32_t k_tiles = params_.sk_params_.divmod_k_tiles_per_sk_unit_.divisor; + uint32_t K_idx = Split_idx * k_tiles; + + // Apply any fixup needed to handle residuals + K_idx += split_start_offset; + k_tiles += additional_k_tiles; + + // K_idx must start at an even k tile for each CTA when ktile_start_alignment_count_ is 2. + // + // * Example + // 53 k_tiles per output tile + // 10 k_tiles for a normal-size split + // 11 k_tiles for each of the first three big units + // + // split 0 : K_idx = [0, 10], k_tiles = 11 -> K_idx = [0, 11], k_tiles = 12 + // split 1 : K_idx = [11, 21], k_tiles = 11 -> K_idx = [12, 21], k_tiles = 10 + // split 2 : K_idx = [22, 32], k_tiles = 11 -> K_idx = [22, 33], k_tiles = 12 + // split 3 : K_idx = [33, 42], k_tiles = 10 -> K_idx = [34, 42], k_tiles = 9 -> K_idx = [34, 43], k_tiles = 10 + // split 4 : K_idx = [43, 52], k_tiles = 10 -> K_idx = [44, 52], k_tiles = 9 + if (params_.sk_params_.ktile_start_alignment_count_ == 2u && K_idx % 2 != 0) { + // If the current CTA's K_idx does not start at an even k tile, give up one k_tile + K_idx += 1; + k_tiles -= 1; + } + if (params_.sk_params_.ktile_start_alignment_count_ == 2u && + (K_idx + k_tiles) % 2 != 0 && + (K_idx + k_tiles) < params_.sk_params_.divmod_tiles_per_output_tile_.divisor) { + // If the next CTA's K_idx would not start at an even k tile, acquire one k_tile + k_tiles += 1; + } + + return { + static_cast(M_idx), + static_cast(N_idx), + static_cast(K_idx), + static_cast(L_idx), + k_tiles, + k_tiles // remaining iterations + }; + } + else { + // Data-parallel case + return { + static_cast(work_tile_info.M_idx), + static_cast(work_tile_info.N_idx), + static_cast(0), // K_idx + static_cast(work_tile_info.L_idx), + static_cast(params_.sk_params_.divmod_tiles_per_output_tile_.divisor), + static_cast(params_.sk_params_.divmod_tiles_per_output_tile_.divisor) + }; + } + } + + // Converts a WorkTileInfo struct to the WorkTileInfo representation + // of the underlying SM100 scheduler.
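// ------------------------------------------------------------------------------------------------
// Editorial sketch (not part of this patch): a host-side rehearsal of the split-K K-range
// assignment in convert_work above, usable to sanity-check the worked example in its comments
// (53 k tiles, 5 splits, ktile_start_alignment_count == 2). Deriving k_tiles_per_split and
// num_big_units from (total_k_tiles, splits) is an assumption made to match that example, and the
// function and variable names are illustrative, not CUTLASS APIs.
#include <cstdint>

struct SplitKRange { std::uint32_t K_idx; std::uint32_t k_tiles; };

inline SplitKRange split_k_range(std::uint32_t split_idx, std::uint32_t splits,
                                 std::uint32_t total_k_tiles,
                                 std::uint32_t ktile_start_alignment_count) {
  std::uint32_t k_tiles_per_split = total_k_tiles / splits;   // 10 in the example
  std::uint32_t num_big_units     = total_k_tiles % splits;   //  3 in the example

  // Big units perform one extra k iteration; later splits start after all big units.
  std::uint32_t additional_k_tiles = (split_idx < num_big_units) ? 1u : 0u;
  std::uint32_t split_start_offset = (split_idx < num_big_units) ? split_idx : num_big_units;

  std::uint32_t k_tiles = k_tiles_per_split + additional_k_tiles;
  std::uint32_t K_idx   = split_idx * k_tiles_per_split + split_start_offset;

  if (ktile_start_alignment_count == 2u && K_idx % 2u != 0u) {
    // This split would start on an odd k tile: give one k tile to the previous split.
    K_idx   += 1u;
    k_tiles -= 1u;
  }
  if (ktile_start_alignment_count == 2u && (K_idx + k_tiles) % 2u != 0u &&
      (K_idx + k_tiles) < total_k_tiles) {
    // The next split would start on an odd k tile: absorb one of its k tiles.
    k_tiles += 1u;
  }
  return {K_idx, k_tiles};
}
// For instance, split_k_range(3, 5, 53, 2) yields {34, 10}, matching "split 3" in the comment above.
// ------------------------------------------------------------------------------------------------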
+ CUTLASS_HOST_DEVICE static + InternalWorkTileInfo + to_underlying_work_tile_info(WorkTileInfo const& work_tile_info) { + return { + work_tile_info.M_idx, + work_tile_info.N_idx, + work_tile_info.L_idx, + work_tile_info.is_valid() + }; + } + + // Returns whether the current parameters contain only data-parallel tiles + CUTLASS_HOST_DEVICE + bool + is_dp_only() const { + return params_.sk_params_.sk_units_ == 0 && params_.sk_params_.divmod_splits_.divisor == 1; + } + + // Returns whether the current parameters are for a split-K decomposition + CUTLASS_HOST_DEVICE + bool + is_split_k() const { + return params_.sk_params_.divmod_splits_.divisor > 1; + } + + // Returns whether the current parameters contain any stream-K work + CUTLASS_HOST_DEVICE + bool + has_sk_work() const { + return params_.sk_params_.sk_units_ > 0; + } + + // Performs reduction across splits for a given output tile + template < + class TiledMma, + class AccEngine, + class AccLayout, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class CopyOpT2R + > + CUTLASS_DEVICE + AccumulatorPipelineState + tmem_fixup( + TiledMma const& tiled_mma, + WorkTileInfo const& work_tile_info, + cute::Tensor& accumulators, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + CopyOpT2R, + uint32_t num_accumulator_mtx = 1, + uint32_t idx_accumulator_mtx = 0) const { + using namespace cute; + static_assert(cute::is_tmem_v, "Accumulator must be in TMEM"); + + using ElementAccumulator = typename AccEngine::element_type; + + constexpr uint32_t ThreadsForFixup = NumThreadsPerWarpGroup; + constexpr uint32_t Offset = static_cast(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); + constexpr uint32_t MaxNumNamedBarriers = 1; + constexpr uint32_t BarrierIdx = 0; + using BarrierManager = NamedBarrierManager; + + // When accumulators reside in TMEM, perform TMEM -> RF loads before performing fixup, + // and perform RF -> TMEM stores after fixup (when the split must compute the epilogue) + auto dummy_gmem_workspace = make_tensor( + make_gmem_ptr(nullptr), + make_layout(take<0,2>(TileShape{}), GenRowMajor{})); // (TILE_M,TILE_N) + + auto dummy_gmem_buffer = tiled_mma.get_slice(0).partition_C(dummy_gmem_workspace); // (MMA,MMA_M,MMA_N) + + auto tmem_load = make_tmem_copy(CopyOpT2R{}, accumulators); + auto tmem_store = make_tmem_copy(cute::TMEM::tmem_load_to_store(CopyOpT2R{}), accumulators); + + auto thr_tmem_load = tmem_load.get_slice(threadIdx.x % ThreadsForFixup); + auto thr_tmem_store = tmem_store.get_slice(threadIdx.x % ThreadsForFixup); + + Tensor tCtAcc = thr_tmem_load.partition_S(accumulators); // (TMEM_LOAD,TMEM_LOAD_MMA,TMEM_LOAD_M,TMEM_LOAD_N) + Tensor tCgAcc = thr_tmem_load.partition_D(dummy_gmem_buffer); // (TMEM_LOAD,TMEM_LOAD_MMA,TMEM_LOAD_M,TMEM_LOAD_N) + auto tCrAcc = make_tensor(shape(tCgAcc)); // (TMEM_LOAD,TMEM_LOAD_MMA,TMEM_LOAD_M,TMEM_LOAD_N) + + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + + // Copy accumulators from tmem to rmem for reduction + copy(tmem_load, tCtAcc, tCrAcc); + + bool should_compute_epilogue = compute_epilogue(work_tile_info); + if (!should_compute_epilogue && (idx_accumulator_mtx == (num_accumulator_mtx - 1))) { + // Splits that do not compute the epilogue must advance the accumulator pipeline + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Perform fixup + UnderlyingStreamKScheduler::template fixup_helper( + params_.sk_params_, work_tile_info, tCrAcc, 
MaxNumNamedBarriers, BarrierIdx, num_accumulator_mtx, idx_accumulator_mtx); + + if (should_compute_epilogue) { + // Splits that compute the epilogue copy the reduced accumulators back to tmem for + // the epilogue to compute on it + copy(tmem_store, tCrAcc, tCtAcc); + } + + return acc_pipe_consumer_state; + } + + + // + // Members + // + + UnderlyingScheduler sm100_scheduler_; + Params params_; + dim3 block_id_in_cluster_; + uint64_t current_work_linear_idx_ = 0; + uint32_t unit_iter_start_ = 0; + + // This might not be needed + bool is_fallback_cluster_ = false; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index c19f33fb9a..b65c45c277 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -120,6 +120,7 @@ class GemmUniversal< typename detail::TileSchedulerSelector< GroupScheduler, ArchTag, TileShape, ClusterShape, + 2, // Default unused parameter - SchedulerPipelineStageCoun ProblemShape>::Scheduler, typename detail::TileSchedulerSelector< void, ArchTag, TileShape, ClusterShape>::Scheduler>; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 62096e825f..6311c60131 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -120,6 +120,7 @@ class GemmUniversal< typename detail::TileSchedulerSelector< GroupScheduler, ArchTag, TileShape, ClusterShape, + 2, // Default unused parameter - SchedulerPipelineStageCoun ProblemShape>::Scheduler, typename detail::TileSchedulerSelector< void, ArchTag, TileShape, ClusterShape>::Scheduler>; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index cafab8b99f..f00a69bb27 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -401,9 +401,7 @@ class GemmUniversal< // Compute m_coord, n_coord, and l_coord with their post-tiled shapes auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); - - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl), compact_col_major(shape<2>(gB_nkl))); - + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); // handles the difference between the rank of Tensor returned by load_input in case they do not have a batch mode auto l_coord = [&] (auto const& gB_nkl_) { // gB_nkl needs to be passed into the lambda because C++17 diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index b3413c8588..960f917d82 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -714,6 +714,19 @@ class PersistentTileSchedulerSm90StreamK { auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = block_id_in_cluster; uint64_t cta_m_in_cluster = static_cast(cta_m_in_cluster_); uint64_t cta_n_in_cluster = static_cast(cta_n_in_cluster_); + + // Determine the CTA's M and N offsets 
within the preferred cluster + // This simply finds the linear offset of the CTA within the cluster, and takes a divmod + // on it depending on the rasterization order used by the scheduler. + uint64_t cluster_linear_work_idx_tmp = params.div_cluster_size(linear_idx) * params.get_cluster_size(); + + if (params.raster_order_ == RasterOrder::AlongN) { + params.divmod_cluster_shape_minor_(cta_n_in_cluster, cta_m_in_cluster, linear_idx - cluster_linear_work_idx_tmp); + } + else { + params.divmod_cluster_shape_minor_(cta_m_in_cluster, cta_n_in_cluster, linear_idx - cluster_linear_work_idx_tmp); + } + return {static_cast(cta_m_in_cluster), static_cast(cta_n_in_cluster), _}; } @@ -946,6 +959,8 @@ class PersistentTileSchedulerSm90StreamK { params.divmod_cluster_blk_major_, params.log_swizzle_size_, params.raster_order_ + , cta_m_in_cluster + , cta_n_in_cluster ); // Set the M, N, and L block offsets diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp index a524630948..b612165eb1 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -58,6 +58,9 @@ struct GroupScheduler { }; // Only used for Grouped GEMMs #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp" //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel::detail { @@ -71,6 +74,7 @@ template < class ArchTag, class TileShape, class ClusterShape + , uint32_t SchedulerPipelineStageCount = 2 , class ProblemShapeType = void > struct TileSchedulerSelector { @@ -82,12 +86,14 @@ template < class ArchTag, class TileShape, class ClusterShape + , uint32_t SchedulerPipelineStageCount > struct TileSchedulerSelector< PersistentScheduler, ArchTag, TileShape, ClusterShape + , SchedulerPipelineStageCount > { using Scheduler = PersistentTileSchedulerSm90; }; @@ -97,30 +103,35 @@ template < class ArchTag, class TileShape, class ClusterShape + , uint32_t SchedulerPipelineStageCount > struct TileSchedulerSelector< void, ArchTag, TileShape, ClusterShape + , SchedulerPipelineStageCount > { using Scheduler = typename TileSchedulerSelector< PersistentScheduler, ArchTag, TileShape, ClusterShape + , SchedulerPipelineStageCount >::Scheduler; }; template < class TileShape, class ClusterShape + , uint32_t SchedulerPipelineStageCount > struct TileSchedulerSelector< StreamKScheduler, arch::Sm90, TileShape, ClusterShape + , SchedulerPipelineStageCount > { using Scheduler = PersistentTileSchedulerSm90StreamK; }; @@ -128,6 +139,7 @@ struct TileSchedulerSelector< template < class TileShape, class ClusterShape + , uint32_t SchedulerPipelineStageCount , class GroupProblemShape > struct TileSchedulerSelector< @@ -135,11 +147,113 @@ struct TileSchedulerSelector< arch::Sm90, TileShape, ClusterShape + , SchedulerPipelineStageCount , GroupProblemShape > { using Scheduler = PersistentTileSchedulerSm90Group; }; + +template +struct TileSchedulerSelector< + PersistentScheduler, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = PersistentTileSchedulerSm100< + ClusterShape, + SchedulerPipelineStageCount>; +}; + +// Ptr-Array kernel may provide a specialized 
ArrayProblemShape type +template +struct TileSchedulerSelector< + PersistentScheduler, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + ProblemShape> { + using Scheduler = PersistentTileSchedulerSm100< + ClusterShape, + SchedulerPipelineStageCount>; +}; + +// Default (void) for Sm100 maps to PersistentTileSchedulerSm100 +template +struct TileSchedulerSelector< + void, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = typename TileSchedulerSelector< + PersistentScheduler, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount>::Scheduler; +}; + +// Default (void) for Sm100 maps to PersistentTileSchedulerSm100 +// Ptr-Array kernel may provide a specialized ArrayProblemShape type +template +struct TileSchedulerSelector< + void, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + ProblemShape> { + using Scheduler = typename TileSchedulerSelector< + PersistentScheduler, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount>::Scheduler; +}; + +// SM100 Group tile scheduler +template < + class TileShape, + class ClusterShape, + uint32_t SchedulerPipelineStageCount, + class GroupProblemShape +> +struct TileSchedulerSelector< + GroupScheduler, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + GroupProblemShape + > { + using Scheduler = PersistentTileSchedulerSm100Group; +}; + +// SM100 stream-K scheduler +template +struct TileSchedulerSelector< + StreamKScheduler, + arch::Sm100, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = PersistentTileSchedulerSm100StreamK< + TileShape, + ClusterShape, + SchedulerPipelineStageCount>; +}; + + + //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 9ac78311d6..aa599a3573 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -189,6 +189,7 @@ struct PersistentTileSchedulerSm90Params { int max_swizzle_size, RasterOrderOptions raster_order_option, bool truncate_by_problem_size=true + , bool bypass_occupancy_calculation=false ) { dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); @@ -199,6 +200,7 @@ struct PersistentTileSchedulerSm90Params { max_swizzle_size, raster_order_option, truncate_by_problem_size + , bypass_occupancy_calculation ); } @@ -214,6 +216,7 @@ struct PersistentTileSchedulerSm90Params { int max_swizzle_size, RasterOrderOptions raster_order_option, bool truncate_by_problem_size=true + , bool bypass_occupancy_calculation=false ) { int const sm_count = hw_info.sm_count; @@ -274,6 +277,7 @@ struct PersistentTileSchedulerSm90Params { } else { int cta_per_device = sm_count; + if (!bypass_occupancy_calculation) { /* * Optimal grid size calculation is based on * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU @@ -281,6 +285,8 @@ struct PersistentTileSchedulerSm90Params { */ constexpr int max_sm_per_gpc = 18; cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); + } + if (raster_order == RasterOrder::AlongN) { launch_grid.y = possibly_truncate( cta_per_device / cluster_shape.m(), @@ -380,7 +386,7 @@ struct PersistentTileSchedulerSm90StreamKParams { // Strategies for computing reductions between CTAs computing 
portions of a given output tile enum class ReductionMode { // Participating CTAs perform reduction in a turnstile fashion in order of the K extent - // covered by each CTA. This requires a lock to be held exclusively be the CTA that is + // covered by each CTA. This requires a lock to be held exclusively by the CTA that is // currently accumulating. // // Turnstile accumulation ensures deterministic numeric behavior when using this mode. @@ -502,6 +508,32 @@ struct PersistentTileSchedulerSm90StreamKParams { ); } + + // Divides dividend by the cluster size in the M dimension + CUTLASS_HOST_DEVICE + uint64_t + truncate_to_cluster_size_m(uint64_t dividend) const { + if (raster_order_ == RasterOrder::AlongN) { + return divmod_cluster_shape_minor_.divide(dividend) * divmod_cluster_shape_minor_.divisor; + } + else { + return divmod_cluster_shape_major_.divide(dividend) * divmod_cluster_shape_major_.divisor; + } + } + + // Divides dividend by the cluster size in the N dimension + CUTLASS_HOST_DEVICE + uint64_t + truncate_to_cluster_size_n(uint64_t dividend) const { + if (raster_order_ == RasterOrder::AlongM) { + return divmod_cluster_shape_minor_.divide(dividend) * divmod_cluster_shape_minor_.divisor; + } + else { + return divmod_cluster_shape_major_.divide(dividend) * divmod_cluster_shape_major_.divisor; + } + } + + CUTLASS_HOST_DEVICE uint64_t get_cluster_size() const { @@ -542,6 +574,7 @@ struct PersistentTileSchedulerSm90StreamKParams { DecompositionMode decomposition_mode, void* workspace, const uint32_t epilogue_subtile = 1u + , uint32_t ktile_start_alignment_count = 1u ) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl( problem_shape, tile_shape, cluster_shape); @@ -561,6 +594,7 @@ struct PersistentTileSchedulerSm90StreamKParams { decomposition_mode, workspace, epilogue_subtile + , ktile_start_alignment_count ); } @@ -580,6 +614,7 @@ struct PersistentTileSchedulerSm90StreamKParams { DecompositionMode decomposition_mode, void* workspace, const uint32_t epilogue_subtile = 1 + , uint32_t ktile_start_alignment_count = 1u ) { #if !defined(__CUDACC_RTC__) @@ -590,6 +625,7 @@ struct PersistentTileSchedulerSm90StreamKParams { } #endif // !defined(__CUDACC_RTC__) + ktile_start_alignment_count_ = ktile_start_alignment_count; UnderlyingParams underlying_params; underlying_params.initialize( problem_blocks, @@ -716,6 +752,7 @@ struct PersistentTileSchedulerSm90StreamKParams { DecompositionMode decomposition_mode, ReductionMode reduction_mode, const uint32_t epilogue_subtile = 1 + , uint32_t ktile_start_alignment_count = 1u ) { uint32_t groups = 0; uint32_t sk_tiles = 0; @@ -749,6 +786,7 @@ struct PersistentTileSchedulerSm90StreamKParams { decomposition_mode, reduction_mode, epilogue_subtile + , ktile_start_alignment_count ); // Given heuristic_mode returned from the heuristic() method, set params fields. 
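The helpers added above (truncate_to_cluster_size_m/n and ctas_per_wave_in_full_clusters) all rely on the same divide-then-multiply rounding so that stream-K work is only distributed across full clusters. A minimal standalone sketch of that arithmetic, using plain integers in place of FastDivmod (the function name and example figures below are illustrative, not part of the patch):

#include <cstdint>
#include <cassert>

// Round `ctas` down to the nearest multiple of `cluster_size`.
// Mirrors the divide-then-multiply pattern of the truncate_to_cluster_size_* helpers.
constexpr uint64_t truncate_to_full_clusters(uint64_t ctas, uint64_t cluster_size) {
  return (ctas / cluster_size) * cluster_size;
}

int main() {
  uint64_t ctas_per_wave = 132;   // e.g., one CTA per SM on a 132-SM device
  uint64_t cluster_size  = 2 * 4; // 2x4 thread block cluster
  // Only 128 of the 132 CTAs form complete clusters; the remainder is excluded
  // from stream-K unit assignment.
  assert(truncate_to_full_clusters(ctas_per_wave, cluster_size) == 128);
  return 0;
}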
@@ -772,6 +810,7 @@ struct PersistentTileSchedulerSm90StreamKParams { splits, epilogue_subtile, reduction_mode + , ktile_start_alignment_count ); } @@ -797,6 +836,7 @@ struct PersistentTileSchedulerSm90StreamKParams { DecompositionMode decomposition_mode, ReductionMode reduction_mode, uint32_t epilogue_subtile + , uint32_t ktile_start_alignment_count ) { // Get block numbers in m, n and l dimensions @@ -805,6 +845,7 @@ struct PersistentTileSchedulerSm90StreamKParams { // Short circuit to basic split-K decomposition uint32_t adapted_splits = adjust_split_count( splits, hw_info.sm_count, k_tiles_per_output_tile + , ktile_start_alignment_count ); sk_splits = adapted_splits; return DecompositionMode::SplitK; @@ -826,6 +867,8 @@ struct PersistentTileSchedulerSm90StreamKParams { ); uint64_t ctas_per_wave = grid.x * grid.y; cluster_size = cluster_shape.m() * cluster_shape.n(); + uint64_t ctas_per_wave_in_full_clusters = (ctas_per_wave / cluster_size) * cluster_size; + // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. sk_tiles = get_num_sk_tiles( output_tiles, @@ -833,6 +876,7 @@ struct PersistentTileSchedulerSm90StreamKParams { cluster_size, k_tiles_per_output_tile, decomposition_mode + , ctas_per_wave_in_full_clusters ); uint64_t dp_tiles = output_tiles - sk_tiles; // Calculate the number of work units covering the data-parallel and stream-K tiles. @@ -846,6 +890,7 @@ struct PersistentTileSchedulerSm90StreamKParams { dp_units = dp_tiles; uint64_t ctas_per_sk_wave = ctas_per_wave; + ctas_per_sk_wave = ctas_per_wave_in_full_clusters; sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); if (decomposition_mode == DecompositionMode::DataParallel || @@ -924,6 +969,7 @@ struct PersistentTileSchedulerSm90StreamKParams { uint32_t splits, uint32_t epilogue_subtile, ReductionMode reduction_mode + , uint32_t ktile_start_alignment_count ) { // The highest priority when customers set as splitk mode, may set // with a adpated splits value rather than the original splits @@ -1025,6 +1071,7 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle_size, raster_order_option, /* truncate_by_problem_size = */false + /* bypass_occupancy_calculation = */, true ); } @@ -1037,6 +1084,7 @@ struct PersistentTileSchedulerSm90StreamKParams { uint64_t cluster_size, uint32_t k_tiles_per_output_tile, DecompositionMode decomposition_mode + , uint64_t ctas_per_wave_in_full_clusters ) { uint32_t full_waves = static_cast(output_tiles / ctas_per_wave); uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); @@ -1054,16 +1102,16 @@ struct PersistentTileSchedulerSm90StreamKParams { uint64_t dp_tiles = dp_waves * ctas_per_wave; uint64_t sk_tiles = output_tiles - dp_tiles; - if (decomposition_mode == DecompositionMode::Heuristic) { - if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) { - // All tiles will be data-parallel tiles if there is either no quantization - // or if there is no work to be split. - return 0; - } - // - // The final wave is not full. Perform some stream-K work. - // + if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) { + // All tiles will be data-parallel tiles if there is either no quantization + // or if there is no work to be split. + return 0; + } + // + // The final wave is not full. Perform some stream-K work. 
+ // + if (decomposition_mode == DecompositionMode::Heuristic) { // Rudimentary heuristic: prefer data-parallel decomposition if we have more than // one wave and the tail wave is more than half full. This is subject to change. uint64_t tail_tiles = output_tiles - (full_waves * ctas_per_wave); @@ -1071,7 +1119,6 @@ struct PersistentTileSchedulerSm90StreamKParams { return 0; } } - return static_cast(sk_tiles); } @@ -1172,14 +1219,17 @@ struct PersistentTileSchedulerSm90StreamKParams { ); uint64_t ctas_per_wave = grid.x * grid.y; uint64_t cluster_size = cluster_shape.m() * cluster_shape.n(); + uint64_t ctas_per_wave_in_full_clusters = (ctas_per_wave / cluster_size) * cluster_size; uint32_t sk_tiles = get_num_sk_tiles( output_tiles, ctas_per_wave, cluster_size, static_cast(k_tiles_per_output_tile), decomposition_mode + , ctas_per_wave_in_full_clusters ); uint64_t ctas_per_sk_wave = ctas_per_wave; + ctas_per_sk_wave = ctas_per_wave_in_full_clusters; uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); uint64_t dp_tiles = output_tiles - sk_tiles; @@ -1187,11 +1237,13 @@ struct PersistentTileSchedulerSm90StreamKParams { (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { splits = adjust_split_count( splits, new_hw_info.sm_count, k_tiles_per_output_tile + , ktile_start_alignment_count ); } bool split_k_required = splits > 1 && (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic); - bool split_k_selected = decomposition_mode == DecompositionMode::Heuristic && + bool split_k_selected = !split_k_required && + decomposition_mode == DecompositionMode::Heuristic && sk_units > sk_tiles && sk_tiles != 0 && sk_units % sk_tiles == 0; @@ -1547,6 +1599,7 @@ struct PersistentTileSchedulerSm90StreamKParams { int splits, int sm_count, uint32_t k_tiles_per_output_tile + , uint32_t ktile_start_alignment_count ) { // Don't split by more than the available number of SMs if (splits > sm_count) { @@ -1561,6 +1614,11 @@ struct PersistentTileSchedulerSm90StreamKParams { // If k_tiles_per_output_tiles / splits == 1, there will be one k_tile per cta // and this violate k_tile start from even requirements. Thus we need to // reduce the number of splits. + if (ktile_start_alignment_count > 1u && + splits > 1 && + k_tiles_per_output_tile / static_cast(splits) == 1) { + splits = k_tiles_per_output_tile / ktile_start_alignment_count; + } return splits; } }; @@ -1809,6 +1867,732 @@ struct PersistentTileSchedulerSm90GroupParams { }; //////////////////////////////////////////////////////////////////////////////// + + +// +// Parameters for SM100 tile schedulers +// + +// Parameters for SM100 persistent tile scheduler +struct PersistentTileSchedulerSm100Params { + + using UnderlyingParams = PersistentTileSchedulerSm90Params; + + using RasterOrder = UnderlyingParams::RasterOrder; + using RasterOrderOptions = UnderlyingParams::RasterOrderOptions; + + uint32_t problem_tiles_m_ = 0; + uint32_t problem_tiles_n_ = 0; + uint32_t problem_tiles_l_ = 0; + FastDivmod divmod_cluster_shape_m_{}; + FastDivmod divmod_cluster_shape_n_{}; + RasterOrder raster_order_ = RasterOrder::AlongM; + int32_t log_swizzle_size_ = 0; + + // Initializes members. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. 
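The split-count clamp introduced above for ktile_start_alignment_count can be read in isolation: if K-tile start indices must be aligned (for example, to pairs of K tiles) and the requested split count would leave each split with a single K tile, the split count is reduced so that, roughly, each split covers at least the alignment amount. A self-contained sketch with illustrative names:

#include <cstdint>
#include <cassert>

// Sketch of the clamp added to adjust_split_count: shrink the split count when
// each split would otherwise receive only one K tile and start alignment > 1.
inline int clamp_splits_for_ktile_alignment(int splits,
                                            uint32_t k_tiles_per_output_tile,
                                            uint32_t alignment) {
  if (alignment > 1u && splits > 1 &&
      k_tiles_per_output_tile / static_cast<uint32_t>(splits) == 1u) {
    splits = static_cast<int>(k_tiles_per_output_tile / alignment);
  }
  return splits;
}

int main() {
  // 8 K tiles split 8 ways would give 1 K tile per split, violating a start
  // alignment of 2, so the split count is reduced to 4.
  assert(clamp_splits_for_ktile_alignment(8, 8u, 2u) == 4);
  // No alignment constraint: the requested split count is kept.
  assert(clamp_splits_for_ktile_alignment(8, 8u, 1u) == 8);
  return 0;
}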
+ void + initialize( + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option + ) { + dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); + initialize( + problem_blocks, + cluster_shape, + hw_info, + max_swizzle_size, + raster_order_option + ); + } + + // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + void + initialize( + dim3 problem_blocks, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option + ) { + + CUTLASS_UNUSED(hw_info); + CUTLASS_UNUSED(max_swizzle_size); + + // Cluster counters in m, n and l dimensions of the problem tiles + problem_tiles_m_ = problem_blocks.x / cluster_shape.m(); + problem_tiles_n_ = problem_blocks.y / cluster_shape.n(); + problem_tiles_l_ = problem_blocks.z; + divmod_cluster_shape_m_ = FastDivmod(cluster_shape.m()); + divmod_cluster_shape_n_ = FastDivmod(cluster_shape.n()); + + raster_order_ = UnderlyingParams::get_rasterization_order(problem_tiles_m_, problem_tiles_n_, raster_order_option); + if (raster_order_option == RasterOrderOptions::Heuristic && raster_order_ == RasterOrder::AlongN) { + // The current implementation of AlongN rasterization for B100 requires swapping the number of clusters along the + // X and Y dimensions of the grid. However, since the grid Y dimension has a smaller range of allowed values + // than the grid X dimension, we must check whether the swapped grid would exceed the grid Y limit. If the + // swapped grid would exceed this limit, simply revert to AlongM mode. + // + // Overflow in the swapped X dimension is not possible. At worst, there will be ((1 << 16) - 1) clusters + // along the original Y dimension of the grid. Even if the cluster M mode is 16, the new grid X value + // will be at most ((1 << 16) - 1) * 16, which is less than the grid X limit of ((1 << 31) - 1). + uint32_t cluster_m = static_cast<uint32_t>(problem_blocks.x) / static_cast<uint32_t>(cluster_shape.m()); + uint32_t new_grid_y = cluster_m * static_cast<uint32_t>(cluster_shape.n()); + + if (new_grid_y > (1 << 16) - 1) { + raster_order_ = RasterOrder::AlongM; + } + } + } + + // Given the inputs, computes the physical grid we should launch. + // This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + BatchedGemmCoord problem_shape, + GemmCoord cta_shape, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option + ) { + + CUTLASS_UNUSED(cluster_shape); + CUTLASS_UNUSED(hw_info); + CUTLASS_UNUSED(max_swizzle_size); + CUTLASS_UNUSED(raster_order_option); + + return get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); + } + + // Get the number of CTA tiles in this problem. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1.
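The fallback in initialize() above exists because CUDA caps gridDim.y at 65535 while gridDim.x allows a far larger range, so swapping cluster counts for AlongN rasterization is only safe when the swapped Y extent fits. A standalone sketch of that check (pick_raster_order and its arguments are illustrative, not part of the patch):

#include <cstdint>
#include <cassert>

enum class RasterOrder { AlongM, AlongN };

// If rasterizing AlongN would place more than 65535 clusters in grid Y,
// fall back to AlongM, mirroring the heuristic override above.
inline RasterOrder pick_raster_order(uint32_t problem_blocks_m,
                                     uint32_t cluster_m,
                                     uint32_t cluster_n,
                                     RasterOrder heuristic_choice) {
  if (heuristic_choice == RasterOrder::AlongN) {
    uint32_t clusters_m = problem_blocks_m / cluster_m;
    uint32_t new_grid_y = clusters_m * cluster_n;
    if (new_grid_y > (1u << 16) - 1u) {
      return RasterOrder::AlongM;
    }
  }
  return heuristic_choice;
}

int main() {
  // 100000 CTAs along M with a 1x1 cluster would need grid.y = 100000 > 65535,
  // so the heuristic AlongN choice is overridden.
  assert(pick_raster_order(100000u, 1u, 1u, RasterOrder::AlongN) == RasterOrder::AlongM);
  return 0;
}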
+ CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl( + BatchedGemmCoord problem_shape, + GemmCoord cta_shape, + GemmCoord cluster_shape) { + + return UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); + } + + // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + static size_t + get_workspace_size( + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle, + RasterOrderOptions raster_order_option + ) { + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); + return get_workspace_size( + problem_blocks, + cluster_shape, + hw_info, + max_swizzle, + raster_order_option + ); + } + + // Version of get_workspace_size that takes in as input the number of CTAs in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + static size_t + get_workspace_size( + dim3 problem_blocks, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle, + RasterOrderOptions raster_order_option + ) { + + CUTLASS_UNUSED(problem_blocks); + CUTLASS_UNUSED(cluster_shape); + CUTLASS_UNUSED(hw_info); + CUTLASS_UNUSED(max_swizzle); + CUTLASS_UNUSED(raster_order_option); + + return 0; + } + + // Initialize the workspace to be used for the kernel. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + static cutlass::Status + initialize_workspace( + void* workspace, + cudaStream_t stream, + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle, + RasterOrderOptions raster_order_option, + CudaHostAdapter *cuda_adapter = nullptr + ) { + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); + return initialize_workspace( + workspace, + stream, + problem_blocks, + cluster_shape, + hw_info, + max_swizzle, + raster_order_option, + cuda_adapter + ); + } + + // Version of initialize_workspace that takes in as input the number of CTAs in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. 
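get_tiled_cta_shape_mnl above delegates to the SM90 helper; the general shape of that computation is a ceiling division of the GEMM extents by the CTA tile, with each mode rounded up so whole clusters can be launched. A rough sketch under those assumptions (the exact rounding lives in the SM90 helper, and the names below are illustrative):

#include <cstdint>
#include <cassert>

struct Dim3 { uint32_t x, y, z; };

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
constexpr uint32_t round_up(uint32_t a, uint32_t b)  { return ceil_div(a, b) * b; }

// Ceiling-divide the problem extents by the CTA tile, then round each mode up
// to a multiple of the cluster extent so complete clusters can be launched.
inline Dim3 tiled_cta_shape_mnl(uint32_t m, uint32_t n, uint32_t l,
                                uint32_t cta_m, uint32_t cta_n,
                                uint32_t cluster_m, uint32_t cluster_n) {
  return {round_up(ceil_div(m, cta_m), cluster_m),
          round_up(ceil_div(n, cta_n), cluster_n),
          l};
}

int main() {
  // 1024x1024x1 problem with 128x128 CTA tiles and a 2x1 cluster -> an 8x8x1 CTA grid.
  Dim3 blocks = tiled_cta_shape_mnl(1024, 1024, 1, 128, 128, 2, 1);
  assert(blocks.x == 8 && blocks.y == 8 && blocks.z == 1);
  return 0;
}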
+ static cutlass::Status + initialize_workspace( + void* workspace, + cudaStream_t stream, + dim3 problem_blocks, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle, + RasterOrderOptions raster_order_option, + CudaHostAdapter *cuda_adapter = nullptr + ) { + + CUTLASS_UNUSED(workspace); + CUTLASS_UNUSED(stream); + CUTLASS_UNUSED(problem_blocks); + CUTLASS_UNUSED(cluster_shape); + CUTLASS_UNUSED(hw_info); + CUTLASS_UNUSED(max_swizzle); + CUTLASS_UNUSED(raster_order_option); + + return cutlass::Status::kSuccess; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Parameters for SM100 persistent stream-K tile scheduler +struct PersistentTileSchedulerSm100StreamKParams { + + using UnderlyingParams = PersistentTileSchedulerSm100Params; + using UnderlyingStreamKParams = PersistentTileSchedulerSm90StreamKParams; + using RasterOrderOptions = UnderlyingParams::RasterOrderOptions; + using ReductionMode = UnderlyingStreamKParams::ReductionMode; + using DecompositionMode = UnderlyingStreamKParams::DecompositionMode; + + using RasterOrder = UnderlyingParams::RasterOrder; + RasterOrder raster_order_ = RasterOrder::AlongM; + int32_t log_swizzle_size_ = 0; + + UnderlyingStreamKParams sk_params_{}; + UnderlyingParams sm100_params_{}; + + // Initializes members. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + void + initialize( + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int splits, + int max_swizzle_size, + RasterOrderOptions raster_order_option, + ReductionMode reduction_mode, + DecompositionMode decomposition_mode, + void* workspace, + uint32_t ktile_start_alignment_count = 1u + ) { + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); + + // Number of k tiles in each output tile + uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); + + initialize( + problem_blocks, + k_tiles_per_output_tile, + cluster_shape, + hw_info, + splits, + max_swizzle_size, + raster_order_option, + reduction_mode, + decomposition_mode, + workspace, + ktile_start_alignment_count + ); + } + + // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. 
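Before forwarding to the underlying SM90 stream-K parameters, the initialize() overload above derives the number of K tiles per output tile with a ceiling division, so a partial trailing K tile still counts as a full iteration. A tiny worked example with illustrative numbers:

#include <cassert>

int main() {
  int K = 4100, tile_k = 64;
  // Ceiling division, as in the stream-K initialize() overload above.
  int k_tiles_per_output_tile = (K + tile_k - 1) / tile_k;
  assert(k_tiles_per_output_tile == 65); // 64 full K tiles plus one partial tile
  return 0;
}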
+ void + initialize( + dim3 problem_blocks, + uint32_t k_tile_per_output_tile, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int splits, + int max_swizzle_size, + RasterOrderOptions raster_order_option, + ReductionMode reduction_mode, + DecompositionMode decomposition_mode, + void* workspace, + uint32_t ktile_start_alignment_count = 1u + ) { + sk_params_.initialize( + problem_blocks, + k_tile_per_output_tile, + cluster_shape, + hw_info, + splits, + max_swizzle_size, + raster_order_option, + reduction_mode, + decomposition_mode, + workspace, + /*epilogue_subtile=*/1, + ktile_start_alignment_count + ); + + log_swizzle_size_ = sk_params_.log_swizzle_size_; + raster_order_ = sk_params_.raster_order_; + + sm100_params_.initialize( + problem_blocks, + cluster_shape, + hw_info, + max_swizzle_size, + RasterOrderOptions::AlongM // Override raster_order to be AlongM, since the SM100 stream-K scheduler does not require grid swapping for raster order selection + ); + } + + // Get the number of CTA tiles in this problem. + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl( + BatchedGemmCoord problem_shape, + GemmCoord cta_shape, + GemmCoord cluster_shape) { + + return UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); + } + + // Given the inputs, computes the physical grid we should launch. + // This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + CUTLASS_HOST_DEVICE + dim3 + get_grid_shape(BatchedGemmCoord problem_shape, GemmCoord cta_shape, GemmCoord cluster_shape) const { + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); + + return get_grid_shape(problem_blocks, cluster_shape); + } + + // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + CUTLASS_HOST_DEVICE + dim3 + get_grid_shape(dim3 problem_blocks, GemmCoord cluster_shape) const { + if (sk_params_.sk_units_ > 0) { + // For stream-K cases, we would, ideally, launch a linear grid of size `sk_params_.units_per_problem_`. + // However doing so raises two potential issues: + // (a) the total number of tiles in the kernel may exceed the amount that can fit in a single + // returned value of a CLC query + // (b) the launched grid would not respect cluster-size divisibility requirements + // + // To circumvent these issues, we must distribute the `sk_params_.units_per_problem_` units of work + // across the X, Y, and Z dimensions of the grid, while ensuring that the X and Y dimensions are + // divisible by cluster size (we ignore Z, as all CUTLASS kernels currently use a cluster shape + // of 1 in the Z dimension). + // + // For convenience, we launch this as "waves" of `sk_params_.sk_units_` CTAs, with the wave count being + // the Z dimension of the grid, and the `sk_params_.sk_units_` CTAs per wave being distributed across + // the X and Y dimensions of the grid in a way that alingns with cluster divisibility requirements. + // + // Thus, the grid that is launched looks like: + // grid = dim3(sk_units_ / cluster.y, cluster.y, waves) + // + // We place sk_units_ / cluster.y in the X dimension of the grid because the CLC query feature + // allocates more bits for the X index values returned in the query. 
+ // + + // For most cases, `sk_params_.sk_units_` will equal the number of available SMs, so this grid will + // naturally represent waves in the true hardware sense. + // + // However, there are some corner cases in which fewer stream-K units are used than the full SM count + // (e.g., if using the full SM count would result in stream-K units that are assigned fewer than the + // minimum number of K tile iterations). In these cases, `sk_params_.units_per_problem_` may not be + // divisible by `sk_params_.sk_units_`, since any data-parallel work performed alongside stream-K + // work is always done in terms of waves of CTAs of number equal to the number of available SMs. + // Therefore, we take the ceiling of the division when determining wave count, and allow the underlying + // stream-K scheduler to determine which indices are in bounds. + uint32_t waves = static_cast( + (sk_params_.units_per_problem_ + sk_params_.sk_units_ - 1) / sk_params_.sk_units_); + + return dim3( + sk_params_.sk_units_ / cluster_shape.n(), + cluster_shape.n(), + waves + ); + } + else { + // Grid launch for data-parallel and basic split-K decomposition. When data-parallel + // mode is used, params.sk_params_.splits = 1. + return dim3(problem_blocks.x, problem_blocks.y, problem_blocks.z * sk_params_.divmod_splits_.divisor); + } + } + + // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + static size_t + get_workspace_size( + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + uint32_t reduction_warp_groups, + uint32_t barrier_bits, + uint32_t element_accumulator_bits, + uint32_t ktile_start_alignment_count = 1 + ) { + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); + uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); + + return get_workspace_size( + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + cluster_shape, + hw_info, + splits, + max_swizzle, + raster_order_option, + decomposition_mode, + reduction_mode, + reduction_warp_groups, + barrier_bits, + element_accumulator_bits, + ktile_start_alignment_count + ); + } + + // Version of get_workspace_size that takes in as input the number of CTAs in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. 
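The grid construction described in the comments above can be summarized as: X and Y together hold one wave of sk_units CTAs, kept divisible by the cluster extents, and Z counts waves, rounded up so every work unit is covered. A standalone sketch with illustrative names and figures:

#include <cstdint>
#include <cassert>

struct Dim3 { uint32_t x, y, z; };

// Distribute one wave of sk_units CTAs across grid X and Y (X gets the larger
// share, since the CLC query allocates more bits to X) and count waves in Z.
inline Dim3 stream_k_grid(uint64_t units_per_problem, uint64_t sk_units, uint32_t cluster_n) {
  uint32_t waves = static_cast<uint32_t>((units_per_problem + sk_units - 1) / sk_units);
  return {static_cast<uint32_t>(sk_units / cluster_n), cluster_n, waves};
}

int main() {
  // 300 work units on a device with 128 stream-K units and a cluster N extent of 2:
  // grid = (64, 2, 3), i.e., 3 waves of 128 CTAs, with the scheduler masking the overhang.
  Dim3 grid = stream_k_grid(300, 128, 2);
  assert(grid.x == 64 && grid.y == 2 && grid.z == 3);
  return 0;
}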
+ static size_t + get_workspace_size( + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + uint32_t reduction_warp_groups, + uint32_t barrier_bits, + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile = 1, + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1 + ) { + return UnderlyingStreamKParams::get_workspace_size( + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + cluster_shape, + hw_info, + splits, + max_swizzle, + raster_order_option, + decomposition_mode, + reduction_mode, + reduction_warp_groups, + barrier_bits, + element_accumulator_bits, + epilogue_subtile, + num_accumulator_mtxs, + ktile_start_alignment_count + ); + } + + // Initialize the workspace to be used for the kernel. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + static cutlass::Status + initialize_workspace( + void* workspace, + cudaStream_t stream, + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + uint32_t reduction_warp_groups, + uint32_t barrier_bits, + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile = 1, + uint32_t num_accumulator_mtxs = 1, + CudaHostAdapter *cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1 + ) { + dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); + uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); + + return initialize_workspace( + workspace, + stream, + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + cluster_shape, + hw_info, + splits, + max_swizzle, + raster_order_option, + decomposition_mode, + reduction_mode, + reduction_warp_groups, + barrier_bits, + element_accumulator_bits, + epilogue_subtile, + num_accumulator_mtxs, + cuda_adapter, + ktile_start_alignment_count + ); + } + + // Version of initialize_workspace that takes in as input the number of CTAs in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. 
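Both get_workspace_size overloads above are meant to be paired with initialize_workspace on the host before launch: query the size, allocate if it is nonzero, then let the scheduler set up its barriers and partial-reduction buffers. A generic host-side sketch of that pattern, not taken from the patch (SchedulerWorkspace and its byte count are placeholders for the real params class and argument lists):

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical stand-in exposing the same query/initialize split as the methods above.
struct SchedulerWorkspace {
  static size_t get_workspace_size(/* problem + tiling description */) { return 1024; }
  static cudaError_t initialize_workspace(void* ws, cudaStream_t stream) {
    return ws ? cudaMemsetAsync(ws, 0, 1024, stream) : cudaSuccess;
  }
};

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // 1) Ask the scheduler how much scratch it needs for this problem.
  size_t bytes = SchedulerWorkspace::get_workspace_size();
  void* workspace = nullptr;
  if (bytes > 0) {
    cudaMalloc(&workspace, bytes);
  }
  // 2) Let the scheduler initialize the workspace before the kernel launch.
  SchedulerWorkspace::initialize_workspace(workspace, stream);

  cudaStreamSynchronize(stream);
  cudaFree(workspace);
  cudaStreamDestroy(stream);
  return 0;
}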
+ static cutlass::Status + initialize_workspace( + void* workspace, + cudaStream_t stream, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord tile_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + uint32_t reduction_warp_groups, + uint32_t barrier_bits, + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile = 1, + uint32_t num_accumulator_mtxs = 1, + CudaHostAdapter *cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1 + ) { + return UnderlyingStreamKParams::initialize_workspace( + workspace, + stream, + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + cluster_shape, + hw_info, + splits, + max_swizzle, + raster_order_option, + decomposition_mode, + reduction_mode, + reduction_warp_groups, + barrier_bits, + element_accumulator_bits, + epilogue_subtile, + num_accumulator_mtxs, + cuda_adapter, + ktile_start_alignment_count + ); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Parameters for SM100 persistent group scheduler (only used for Grouped Gemms) +template +struct PersistentTileSchedulerSm100GroupParams { + + using UnderlyingSm90Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = typename UnderlyingSm90Params::RasterOrder; + using RasterOrderOptions = typename UnderlyingSm90Params::RasterOrderOptions; + + UnderlyingSm90Params params_sm90_{}; + + // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + void + initialize( + dim3 problem_blocks, + int32_t groups, + ProblemShape* problem_shapes, + ProblemShape const* host_problem_shapes, + GemmCoord cta_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option + ) { + + params_sm90_.initialize( + problem_blocks, + groups, + problem_shapes, + host_problem_shapes, + cta_shape, + cluster_shape, + hw_info, + max_swizzle_size, + raster_order_option + ); + } + + // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl(GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) { + return UnderlyingSm90Params::get_tiled_cta_shape_mnl(cluster_shape, cta_m, cta_n); + } + + // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. 
+ CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + dim3 problem_blocks, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option, + bool truncate_by_problem_size = true, + bool is_static_cluster_shape = false) { + + int const sm_count = hw_info.sm_count; + + // Round up to nearest multiple of swizzle_size along each mode + auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); + auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + + int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; + + RasterOrder raster_order = get_rasterization_order( + problem_blocks_m, + problem_blocks_n, + raster_order_option + ); + + dim3 launch_grid; + + if (raster_order == RasterOrder::AlongN) { + launch_grid = dim3(cluster_shape.m(), 1, 1); + } + else { + launch_grid = dim3(1, cluster_shape.n(), 1); + } + + auto possibly_truncate = [&](int x, int y) { + if (truncate_by_problem_size) { + return platform::min(x, y); + } + else { + return x; + } + }; + + if (is_static_cluster_shape) { + // The else path is generic, however, we can avoid some divs if we know cluster size is 1 + auto cluster_size = cluster_shape.m() * cluster_shape.n(); + if (cluster_size == 1) { + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = possibly_truncate(sm_count, problem_blocks_total); + } + else { + launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); + } + } + else { + constexpr int max_sm_per_gpc = 20; + int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = possibly_truncate( + cta_per_device / cluster_shape.m(), + problem_blocks_total / cluster_shape.m()); + } + else { + launch_grid.x = possibly_truncate( + cta_per_device / cluster_shape.n(), + problem_blocks_total / cluster_shape.n()); + } + CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using heuristics = " + "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); + } + } + else { + // With preferred clusters, we can launch the largest possible persistent grid (rounded up to cluster dims) + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = ((possibly_truncate(sm_count, problem_blocks_total) / cluster_shape.m()) / cluster_shape.n()) * cluster_shape.n(); + } + else { + launch_grid.x = ((possibly_truncate(sm_count, problem_blocks_total) / cluster_shape.n()) / cluster_shape.m()) * cluster_shape.m(); + } + CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using preferred clusters = " + "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); + } + return launch_grid; + } + + CUTLASS_HOST_DEVICE + static int32_t + get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { + return UnderlyingSm90Params::get_log_swizzle_size(problem_ctas_m, problem_ctas_n, max_swizzle_size); + } + + CUTLASS_HOST_DEVICE + static RasterOrder + get_rasterization_order( + uint32_t tiles_m, + uint32_t tiles_n, + RasterOrderOptions raster_order_option + ) { + return UnderlyingSm90Params::get_rasterization_order(tiles_m, tiles_n, raster_order_option); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + } // namespace detail } // 
namespace kernel } // namespace gemm diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 80ddefa121..8b61412568 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -79,6 +79,10 @@ struct integer_subbyte { integer_subbyte(T value) : integer_subbyte(static_cast(value)) {} + CUTLASS_HOST_DEVICE + integer_subbyte(float value) + : integer_subbyte(static_cast(value)) {} + // CUTLASS code commonly converts both signed and unsigned integers // into integer_subbyte, so the class provides both explicit // conversions. @@ -203,6 +207,11 @@ using int4b_t = integer_subbyte<4, true>; /// 4-bit Unsigned integer type using uint4b_t = integer_subbyte<4, false>; + +/// 6-bit unsigned integer type +using uint6b_t = integer_subbyte<6, false>; + + /// 1-bit binary type using bin1_t = bool; diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index 1d61904a9a..c24e2bab88 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -51,6 +51,9 @@ struct KernelHardwareInfo { // Kernel properties int max_active_clusters = 0; // Maximum number of clusters that could co-exist on the target device. + dim3 cluster_shape = {0,0,0}; + dim3 cluster_shape_fallback = {0,0,0}; + // // Methods // diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index d708fd7ab3..fde8eb0770 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -2102,19 +2102,14 @@ struct NumericArrayConverterPacked4Element { } }; -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { +struct NumericArrayConverterPacked4Element { using result_element = float; - using source_element = cutlass::float_e5m2_t; + using source_element = float_ue4m3_t; using result_type = Array; using source_type = Array; @@ -2131,8 +2126,8 @@ struct NumericArrayConverterPacked4Element "{\n" \ ".reg .b16 lo, hi;\n" \ "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ - "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); @@ -2163,12 +2158,12 @@ struct NumericArrayConverterPacked4Element } }; -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; +struct NumericArrayConverterPacked4Element { + using result_element = float_ue4m3_t; using source_element = float; using result_type = Array; @@ -2185,8 +2180,8 @@ struct NumericArrayConverterPacked4Element { "{\n" \ ".reg .b16 lo;\n" \ ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ - "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ "mov.b32 %0, {lo, hi};\n" \ "}" \ : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); @@ -2213,36 +2208,47 @@ 
struct NumericArrayConverterPacked4Element { ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Partial specializations for Array <=> Array +// Partial specializations for Array <=> Array // ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e4m3_t; +struct NumericArrayConverterPacked4Element { + using result_element = float; + using source_element = float_ue8m0_t; using result_type = Array; using source_type = Array; + using BfloatArr = Array; static FloatRoundStyle const round_style = Round; CUTLASS_DEVICE static result_type convert(source_type const & source) { - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out[2]; + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint32_t out_fp16[2]; uint32_t const& src_packed = reinterpret_cast(source); asm volatile( \ "{\n" \ ".reg .b16 lo, hi;\n" \ "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ - "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); - return reinterpret_cast(out); + "cvt.rn.bf16x2.ue8m0x2 %0, lo;\n" \ + "cvt.rn.bf16x2.ue8m0x2 %1, hi;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); + + NumericArrayConverter bf2fp32_converter; + auto res0 = bf2fp32_converter(reinterpret_cast &>(out_fp16[0])); + auto res1 = bf2fp32_converter(reinterpret_cast &>(out_fp16[1])); + + result_type out; + out[0] = res0[0]; + out[1] = res0[1]; + out[2] = res1[0]; + out[3] = res1[1]; + return out; #else result_type result; NumericConverter converter; @@ -2262,45 +2268,41 @@ struct NumericArrayConverterPacked4Element <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::half_t; + +/// Partial specialization for Array <= Array +template <> +struct NumericArrayConverterPacked4Element { + using result_element = float_ue8m0_t; + using source_element = float; using result_type = Array; using source_type = Array; - static FloatRoundStyle const round_style = Round; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_infinity; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - #if defined(CUDA_PTX_FP8_CVT_ENABLED) + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) uint32_t out; - uint32_t const* src_packed = reinterpret_cast(&source); - asm volatile( \ "{\n" \ ".reg .b16 lo;\n" \ ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ - "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ + "cvt.rp.satfinite.ue8m0x2.f32 lo, %2, %1;\n" \ + "cvt.rp.satfinite.ue8m0x2.f32 hi, %4, %3;\n" \ "mov.b32 %0, {lo, hi};\n" \ "}" \ - : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); return reinterpret_cast(out); #else result_type result; - NumericConverter converter; + NumericConverter converter; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 4; ++i) { result[i] = converter(source[i]); } - return result; #endif } @@ -2311,41 +2313,35 @@ struct NumericArrayConverterPacked4Element } }; -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial 
specializations for Array <=> Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::half_t; - using source_element = cutlass::float_e5m2_t; +/// Partial specialization for Array <= Array +template <> +struct NumericArrayConverterPacked4Element { + using result_element = float_ue8m0_t; + using source_element = float; using result_type = Array; using source_type = Array; - static FloatRoundStyle const round_style = Round; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out[2]; - uint32_t const& src_packed = reinterpret_cast(source); + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint32_t out; asm volatile( \ "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ - "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ - "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rz.satfinite.ue8m0x2.f32 lo, %2, %1;\n" \ + "cvt.rz.satfinite.ue8m0x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); + return reinterpret_cast(out); #else result_type result; - NumericConverter converter; + NumericConverter converter; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 4; ++i) { @@ -2362,47 +2358,21 @@ struct NumericArrayConverterPacked4Element <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::half_t; +struct NumericArrayConverterPacked4Element { + using result_element = float_ue8m0_t; + using source_element = float; using result_type = Array; using source_type = Array; - static FloatRoundStyle const round_style = Round; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_infinity; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - - #if defined(CUDA_PTX_FP8_CVT_ENABLED) - uint32_t out; - uint32_t const* src_packed = reinterpret_cast(&source); - - asm volatile( \ - "{\n" \ - ".reg .b16 lo;\n" \ - ".reg .b16 hi;\n" \ - "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ - "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ - "mov.b32 %0, {lo, hi};\n" \ - "}" \ - : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); - - return reinterpret_cast(out); - #else - result_type result; - NumericConverter converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - result[i] = converter(source[i]); - } - - return result; - #endif + //default maps to RP mode. 
+ return NumericArrayConverterPacked4Element{}(source); } CUTLASS_HOST_DEVICE @@ -2411,19 +2381,21 @@ struct NumericArrayConverterPacked4Element } }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Partial specializations for Array <=> Array +// Partial specializations for Array <=> Array // ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::bfloat16_t; - using source_element = cutlass::float_e4m3_t; +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::detail::float_e2m3_unpack8bits_t; + using source_element = float; using result_type = Array; using source_type = Array; @@ -2432,20 +2404,20 @@ struct NumericArrayConverterPacked4Element src2float; - Array tmp_floats = src2float(source); + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t out; - // Convert float to bf16 - result_type out; - Array* packed_tmp = reinterpret_cast*>(&tmp_floats); - Array* packed_out = reinterpret_cast*>(&out); - NumericArrayConverter float2result; - packed_out[0] = float2result(packed_tmp[0]); - packed_out[1] = float2result(packed_tmp[1]); + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e2m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e2m3x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - return out; + return reinterpret_cast(out); #else result_type result; NumericConverter converter; @@ -2465,13 +2437,13 @@ struct NumericArrayConverterPacked4Element <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; - using source_element = cutlass::bfloat16_t; +struct NumericArrayConverterPacked4Element { + using result_element = float; + using source_element = cutlass::detail::float_e2m3_unpack8bits_t; using result_type = Array; using source_type = Array; @@ -2480,18 +2452,27 @@ struct NumericArrayConverterPacked4Element tmp; - Array* packed_tmp = reinterpret_cast*>(&tmp); - Array const* packed_source = reinterpret_cast const*>(&source); - NumericArrayConverter src2float; - packed_tmp[0] = src2float(packed_source[0]); - packed_tmp[1] = src2float(packed_source[1]); + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t out_fp16[2]; + uint32_t const& src_packed = reinterpret_cast(source); - // Convert float to f8 - NumericArrayConverterPacked4Element float2result; - return float2result(tmp); + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e2m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e2m3x2 %1, hi;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); + float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + out[2] = res1.x; + out[3] = res1.y; + return out; #else result_type result; NumericConverter converter; @@ -2513,17 +2494,17 @@ struct NumericArrayConverterPacked4Element <=> Array +// Partial specializations for Array <=> Array // 
///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::bfloat16_t; - using source_element = cutlass::float_e5m2_t; +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::detail::float_e3m2_unpack8bits_t; + using source_element = float; using result_type = Array; using source_type = Array; @@ -2532,20 +2513,20 @@ struct NumericArrayConverterPacked4Element src2float; - Array tmp_floats = src2float(source); + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t out; - // Convert float to bf16 - result_type out; - Array* packed_tmp = reinterpret_cast*>(&tmp_floats); - Array* packed_out = reinterpret_cast*>(&out); - NumericArrayConverter float2result; - packed_out[0] = float2result(packed_tmp[0]); - packed_out[1] = float2result(packed_tmp[1]); + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e3m2x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e3m2x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); - return out; + return reinterpret_cast(out); #else result_type result; NumericConverter converter; @@ -2565,13 +2546,14 @@ struct NumericArrayConverterPacked4Element <= Array + +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::bfloat16_t; +struct NumericArrayConverterPacked4Element { + using result_element = float; + using source_element = cutlass::detail::float_e3m2_unpack8bits_t; using result_type = Array; using source_type = Array; @@ -2580,18 +2562,27 @@ struct NumericArrayConverterPacked4Element tmp; - Array* packed_tmp = reinterpret_cast*>(&tmp); - Array const* packed_source = reinterpret_cast const*>(&source); - NumericArrayConverter src2float; - packed_tmp[0] = src2float(packed_source[0]); - packed_tmp[1] = src2float(packed_source[1]); + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t out_fp16[2]; + uint32_t const& src_packed = reinterpret_cast(source); - // Convert float to f8 - NumericArrayConverterPacked4Element float2result; - return float2result(tmp); + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e3m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e3m2x2 %1, hi;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); + float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + out[2] = res1.x; + out[3] = res1.y; + return out; #else result_type result; NumericConverter converter; @@ -2611,18 +2602,19 @@ struct NumericArrayConverterPacked4Element <=> Array +// Partial specializations for Array <=> Array // ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { - using result_element = cutlass::float_e4m3_t; +struct NumericArrayConverterPacked4Element { + using result_element = float; using source_element = 
cutlass::float_e5m2_t; using result_type = Array; @@ -2631,6 +2623,29 @@ struct NumericArrayConverterPacked4Element(source); + + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); + float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + out[2] = res1.x; + out[3] = res1.y; + return out; + #else result_type result; NumericConverter converter; @@ -2640,6 +2655,7 @@ struct NumericArrayConverterPacked4Element <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverterPacked4Element { +struct NumericArrayConverterPacked4Element { using result_element = cutlass::float_e5m2_t; - using source_element = cutlass::float_e4m3_t; + using source_element = float; using result_type = Array; using source_type = Array; @@ -2662,6 +2678,22 @@ struct NumericArrayConverterPacked4Element(out); + #else result_type result; NumericConverter converter; @@ -2671,6 +2703,7 @@ struct NumericArrayConverterPacked4Element <=> Array -// Array <=> Array -// using packed converter under the hood +// Partial specializations for Array <=> Array // ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array template < - typename T, - typename S, - int N, FloatRoundStyle Round > -struct PackedNumericArrayConverter { - using result_element = T; - using source_element = S; - - using result_type = Array; - using source_type = Array; +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::half_t; + using source_element = cutlass::float_e4m3_t; + using result_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; -private: - using packed_result_type = Array; - using packed_source_type = Array; - -public: CUTLASS_DEVICE static result_type convert(source_type const & source) { - result_type result; - packed_result_type* packed_result = reinterpret_cast(&result); - const packed_source_type* packed_source = reinterpret_cast(&source); - - detail::NumericArrayConverterPacked4Element packed_converter; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - packed_result[i] = packed_converter(packed_source[i]); - } - // Handle leftovers + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out[2]; + uint32_t const& src_packed = reinterpret_cast(source); + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ + "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); + return reinterpret_cast(out); + #else + result_type result; NumericConverter converter; + CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N % 4; ++i) { - int idx = ((N / 4) * 4) + i; - result[idx] = converter(source[idx]); + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); } return result; + #endif } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const{ + result_type operator()(source_type const &s) const { return convert(s); } }; -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < - typename T, - int N, FloatRoundStyle Round > -struct NumericArrayConverter : - public 
PackedNumericArrayConverter {}; +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::float_e4m3_t; + using source_element = cutlass::half_t; -/// Partial specialization for Array <= Array -template < - typename T, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + CUTLASS_DEVICE + static result_type convert(source_type const & source) { -/// Partial specialization for Array <= Array -template < - typename S, - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out; + uint32_t const* src_packed = reinterpret_cast(&source); -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter : - public PackedNumericArrayConverter {}; + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for Array <= Array -/// Conversion is performed with saturation regardless of setting of -/// the `Round` template parameter. 
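// The half_t <-> FP8 specializations here all share one shape: a PTX fast path guarded by
// CUDA_PTX_FP8_CVT_ENABLED that converts two f16x2 registers per 4-element pack, and a
// scalar NumericConverter fallback. A minimal standalone sketch of the half_t -> e5m2
// narrowing (the helper name is illustrative, not a CUTLASS API; assumes <cutlass/array.h>
// and <cutlass/numeric_types.h> are in scope):
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e5m2_t, 4>
half4_to_e5m2x4_sketch(cutlass::Array<cutlass::half_t, 4> const& source) {
  uint32_t out;
  // Four half_t values occupy two 32-bit registers (two f16x2 packs).
  uint32_t const* src_packed = reinterpret_cast<uint32_t const*>(&source);
  asm volatile(
      "{\n"
      ".reg .b16 lo;\n"
      ".reg .b16 hi;\n"
      // Each cvt narrows one f16x2 register into two saturated e5m2 bytes.
      "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n"
      "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n"
      "mov.b32 %0, {lo, hi};\n"
      "}"
      : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1]));
  // Reinterpret the four packed e5m2 bytes as the result array.
  return reinterpret_cast<cutlass::Array<cutlass::float_e5m2_t, 4> const&>(out);
}
#endif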
+/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::half_t; + using source_element = cutlass::float_e5m2_t; - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - NumericConverter destination_converter; + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out[2]; + uint32_t const& src_packed = reinterpret_cast(source); + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); + return reinterpret_cast(out); + #else result_type result; - result[0] = destination_converter(source[0]); + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + return result; + #endif } CUTLASS_HOST_DEVICE @@ -2836,21 +2863,47 @@ struct NumericArrayConverter { } }; +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::float_e5m2_t; + using source_element = cutlass::half_t; - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - NumericConverter destination_converter; + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out; + uint32_t const* src_packed = reinterpret_cast(&source); + + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); + + return reinterpret_cast(out); + #else result_type result; - result[0] = destination_converter(source[0]); + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + return result; + #endif } CUTLASS_HOST_DEVICE @@ -2859,31 +2912,52 @@ struct NumericArrayConverter { } }; -// To convert a FP32 to Int that has less than 32 bits, we need to convert it to int32 first. 
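// Usage sketch: these 4-element detail converters are not used directly; they are reached
// through NumericArrayConverter, which (via the PackedNumericArrayConverter wrapper later
// in this file) tiles larger arrays into packed 4-element conversions plus a scalar tail.
// The kernel below is illustrative only; the array and converter types are existing CUTLASS APIs.
//
//   #include <cutlass/array.h>
//   #include <cutlass/numeric_types.h>
//   #include <cutlass/numeric_conversion.h>
//
//   __global__ void convert_e5m2_to_half_sketch(
//       cutlass::Array<cutlass::float_e5m2_t, 8> const* in,
//       cutlass::Array<cutlass::half_t, 8>* out) {
//     cutlass::NumericArrayConverter<cutlass::half_t, cutlass::float_e5m2_t, 8> converter;
//     *out = converter(*in);  // two packed 4-element e5m2 -> f16x2 conversions under the hood
//   }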
+///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array template < - typename T, - int N, FloatRoundStyle Round > -struct NumericArrayFP32ToIntConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::bfloat16_t; + using source_element = cutlass::float_e4m3_t; - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; - static_assert(cutlass::platform::numeric_limits::is_integer, "the dest type has to be int."); - - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - // Convert float to int - Array temporary; - NumericArrayConverter compute_converter; - temporary = compute_converter(source); + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert f8 to float + NumericArrayConverterPacked4Element src2float; + Array tmp_floats = src2float(source); - // Convert to int to int8_t - NumericArrayConverter destination_converter; - return destination_converter(temporary); + // Convert float to bf16 + result_type out; + Array* packed_tmp = reinterpret_cast*>(&tmp_floats); + Array* packed_out = reinterpret_cast*>(&out); + NumericArrayConverter float2result; + packed_out[0] = float2result(packed_tmp[0]); + packed_out[1] = float2result(packed_tmp[1]); + + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } CUTLASS_HOST_DEVICE @@ -2892,20 +2966,44 @@ struct NumericArrayFP32ToIntConverter { } }; - +/// Partial specialization for Array <= Array template < - int N, FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::float_e4m3_t; + using source_element = cutlass::bfloat16_t; - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert bf16 to float + Array tmp; + Array* packed_tmp = reinterpret_cast*>(&tmp); + Array const* packed_source = reinterpret_cast const*>(&source); + NumericArrayConverter src2float; + packed_tmp[0] = src2float(packed_source[0]); + packed_tmp[1] = src2float(packed_source[1]); + + // Convert float to f8 + NumericArrayConverterPacked4Element float2result; + return float2result(tmp); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } CUTLASS_HOST_DEVICE @@ -2914,19 +3012,52 @@ struct NumericArrayConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array template < - int N, FloatRoundStyle Round > -struct 
NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::bfloat16_t; + using source_element = cutlass::float_e5m2_t; - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert f8 to float + NumericArrayConverterPacked4Element src2float; + Array tmp_floats = src2float(source); + + // Convert float to bf16 + result_type out; + Array* packed_tmp = reinterpret_cast*>(&tmp_floats); + Array* packed_out = reinterpret_cast*>(&out); + NumericArrayConverter float2result; + packed_out[0] = float2result(packed_tmp[0]); + packed_out[1] = float2result(packed_tmp[1]); + + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } CUTLASS_HOST_DEVICE @@ -2935,19 +3066,44 @@ struct NumericArrayConverter { } }; +/// Partial specialization for Array <= Array template < - int N, FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::float_e5m2_t; + using source_element = cutlass::bfloat16_t; - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert bf16 to float + Array tmp; + Array* packed_tmp = reinterpret_cast*>(&tmp); + Array const* packed_source = reinterpret_cast const*>(&source); + NumericArrayConverter src2float; + packed_tmp[0] = src2float(packed_source[0]); + packed_tmp[1] = src2float(packed_source[1]); + + // Convert float to f8 + NumericArrayConverterPacked4Element float2result; + return float2result(tmp); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } CUTLASS_HOST_DEVICE @@ -2956,19 +3112,35 @@ struct NumericArrayConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array template < - int N, FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::float_e4m3_t; + using source_element = cutlass::float_e5m2_t; - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - NumericArrayFP32ToIntConverter converter; - return converter(source); + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; } 
CUTLASS_HOST_DEVICE @@ -2977,39 +3149,29 @@ struct NumericArrayConverter { } }; -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ - ((__CUDACC_VER_MAJOR__ > 10) || \ - ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) - -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { + using result_element = cutlass::float_e5m2_t; + using source_element = cutlass::float_e4m3_t; - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { + result_type result; + NumericConverter converter; - unsigned out; - - asm volatile( - "{ .reg .u32 r4;" - "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" - "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" - "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" - "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" - "}" - : "=r"(out) - : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), - "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } - return reinterpret_cast(out); + return result; } CUTLASS_HOST_DEVICE @@ -3018,69 +3180,178 @@ struct NumericArrayConverter { } }; -/// Partial specialization for Array <= Array +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for: +// Array <=> Array +// Array <=> Array +// using packed converter under the hood +// +///////////////////////////////////////////////////////////////////////////////////////////////// + template < + typename T, + typename S, int N, FloatRoundStyle Round > -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); +struct PackedNumericArrayConverter { + using result_element = T; + using source_element = S; - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; + using result_type = Array; + using source_type = Array; - CUTLASS_HOST_DEVICE - static result_type convert(source_type const & source) { + static FloatRoundStyle const round_style = Round; - NumericArrayConverter convert_vector_; +private: + using packed_result_type = Array; + using packed_source_type = Array; +public: + CUTLASS_DEVICE + static result_type convert(source_type const & source) { result_type result; + packed_result_type* packed_result = reinterpret_cast(&result); + const packed_source_type* packed_source = reinterpret_cast(&source); - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); + detail::NumericArrayConverterPacked4Element packed_converter; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); + for (int i = 0; i < N / 4; ++i) { + packed_result[i] = packed_converter(packed_source[i]); + } + + // Handle leftovers + NumericConverter converter; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N % 4; ++i) { + int idx = ((N / 4) * 4) + i; + result[idx] = converter(source[idx]); } return result; } CUTLASS_HOST_DEVICE - result_type 
operator()(source_type const &s) const { + result_type operator()(source_type const &s) const{ return convert(s); } }; -/// Partial specialization for Array <= Array +/// Partial specialization for Array <= Array template < + typename T, + int N, FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; - using result_type = Array; - using source_type = Array; +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = float_ue8m0_t; + + using result_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & source) { - unsigned out; + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint32_t out_fp16; + uint16_t const& src_packed = reinterpret_cast(source); - asm volatile( - "{ .reg .u32 r4;" - "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" - "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" - "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" - "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" - "}" - : "=r"(out) - : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), - "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); + asm volatile( \ + "{\n" \ + "cvt.rn.bf16x2.ue8m0x2 %0, %1;\n" \ + "}\n" : "=r"(out_fp16): "h"(src_packed)); - return reinterpret_cast(out); + NumericArrayConverter bf2fp32_converter; + auto res0 = bf2fp32_converter(reinterpret_cast &>(out_fp16)); + + result_type out; + out[0] = res0[0]; + out[1] = res0[1]; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif } CUTLASS_HOST_DEVICE @@ -3089,34 +3360,39 @@ struct NumericArrayConverter { } }; -/// Partial specialization for Array <= Array -template < - int N, - FloatRoundStyle Round -> -struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); +/// Partial specialization for Array <= Array +template <> +struct NumericArrayConverter { + using result_element = float_ue8m0_t; + using source_element = float; - using result_type = Array; - using source_type = 
Array; - static FloatRoundStyle const round_style = Round; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_infinity; CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - NumericArrayConverter convert_vector_; + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint16_t out; + asm volatile( \ + "{\n" \ + "cvt.rp.satfinite.ue8m0x2.f32 %0, %2, %1;\n" \ + "}" \ + : "=h"(out) : "f"(source[0]), "f"(source[1])); + return reinterpret_cast(out); + #else result_type result; - - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); + NumericConverter converter; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); } return result; + #endif } CUTLASS_HOST_DEVICE @@ -3125,97 +3401,1328 @@ struct NumericArrayConverter { } }; -#endif // Conditional guards to enable partial specialization for packed integers - -namespace detail { +/// Partial specialization for Array <= Array +template <> +struct NumericArrayConverter { + using result_element = float_ue8m0_t; + using source_element = float; - /* - A helper class that can vectorize a numeric converter with implementation for several vector widths. + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - The vector widths must be giving in decreasing order or width, and must be a power of 2. + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { - The vector converters must produce identical results to the scalar converters for consistency. - */ - class VectorizedConverter { - private: - // Base case to handle remainder elements as scalars. - template - CUTLASS_DEVICE - static void convert_helper( - typename ArrayConverter::result_type& result, - typename ArrayConverter::source_type const& source) { + #if defined(CUDA_PTX_UE8M0_CVT_ENABLED) + uint16_t out; + asm volatile( \ + "{\n" \ + "cvt.rz.satfinite.ue8m0x2.f32 %0, %2, %1;\n" \ + "}" \ + : "=h"(out) : "f"(source[0]), "f"(source[1])); - using ElementRes = typename ArrayConverter::result_type::Element; - using ElementSrc = typename ArrayConverter::source_type::Element; - // If no more converters, handle the remaining elements as scalars. 
- constexpr int total_elements = ArrayConverter::result_type::kElements; - constexpr int remainder = total_elements - Offset; - static_assert(remainder == (total_elements % ParentWidth), "Unexpected remainder."); + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; - typename ArrayConverter::ScalarConverter scalar_converter; - CUTLASS_PRAGMA_UNROLL + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_ue8m0_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + return NumericArrayConverter{}(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = float_ue4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out_fp16; + uint16_t const& src_packed = reinterpret_cast(source); + + asm volatile( \ + "{\n" \ + "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ + "}\n" : "=r"(out_fp16): "h"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16)); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_ue4m3_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t out; + + asm volatile( \ + "{\n" \ + "cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;\n" \ + "}" \ + : "=h"(out) : "f"(source[0]), "f"(source[1])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for 
Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + + +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + + +/// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = cutlass::float_e2m1_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t out_fp16[4]; + uint32_t const& src_packed = reinterpret_cast(source); + + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1, byte2, byte3;\n" \ + "mov.b32 {byte0, byte1, byte2, byte3}, %4;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ + "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ + "cvt.rn.f16x2.e2m1x2 %3, byte3;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) , "=r"(out_fp16[2]), "=r"(out_fp16[3]): "r"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); + float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); + float2 res2 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[2])); + float2 res3 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[3])); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + out[2] = res1.x; + out[3] = res1.y; + out[4] = res2.x; + out[5] = res2.y; + out[6] = res3.x; + out[7] = res3.y; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial 
specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 8), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e2m1_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t tmp; + asm volatile( \ + "{\n" \ + ".reg .b8 byte0;\n" \ + ".reg .b8 byte1;\n" \ + ".reg .b8 byte2;\n" \ + ".reg .b8 byte3;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" \ + "}" \ + : "=r"(tmp) : "f"(source[0]), "f"(source[1])); + + uint8_t out = (tmp & 0xff); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = cutlass::float_e2m1_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + unsigned out; + asm volatile( \ + "{\n" \ + ".reg .b8 byte0;\n" \ + ".reg .b8 byte1;\n" \ + ".reg .b8 byte2;\n" \ + ".reg .b8 byte3;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" \ + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3]), + "f"(source[4]), "f"(source[5]), "f"(source[6]), "f"(source[7])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e2m1_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static 
result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + unsigned out; + asm volatile( \ + "{\n" \ + ".reg .b8 byte0;\n" \ + ".reg .b8 byte1;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ + "mov.b32 %0, {byte0, byte1, 0, 0};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 8), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +/// Conversion is performed with saturation regardless of setting of +/// the `Round` template parameter. +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericConverter destination_converter; + result_type result; + result[0] = destination_converter(source[0]); + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericConverter destination_converter; + result_type result; + result[0] = destination_converter(source[0]); + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +// To convert a FP32 to Int that has less than 32 bits, we need to convert it to int32 first. 
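// The template that follows implements the two-stage idea in the comment above. A minimal
// per-element sketch of the same idea (the helper name and the explicit clamp are
// illustrative; the converter below instead chains a float -> int NumericArrayConverter
// with an int -> destination-integer NumericArrayConverter):
CUTLASS_DEVICE
int8_t float_to_s8_sketch(float x) {
  int32_t wide = __float2int_rn(x);                       // stage 1: FP32 -> int32, round to nearest even
  wide = wide < -128 ? -128 : (wide > 127 ? 127 : wide);  // stage 2: narrow to int8_t (clamped here for safety)
  return static_cast<int8_t>(wide);
}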
+template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayFP32ToIntConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + static_assert(cutlass::platform::numeric_limits::is_integer, "the dest type has to be int."); + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + // Convert float to int + Array temporary; + + NumericArrayConverter compute_converter; + temporary = compute_converter(source); + + // Convert to int to int8_t + NumericArrayConverter destination_converter; + return destination_converter(temporary); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned out; + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(out) + : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), + "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); + + return reinterpret_cast(out); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= 
Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 8), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned out; + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(out) + : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), + "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); + + return reinterpret_cast(out); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 8), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +#endif // Conditional guards to enable partial specialization for packed integers + +namespace detail { + + /* + A helper class that can vectorize a numeric converter with implementation for several vector widths. + + The vector widths must be giving in decreasing order or width, and must be a power of 2. + + The vector converters must produce identical results to the scalar converters for consistency. + */ + class VectorizedConverter { + private: + // Base case to handle remainder elements as scalars. + template + CUTLASS_DEVICE + static void convert_helper( + typename ArrayConverter::result_type& result, + typename ArrayConverter::source_type const& source) { + + using ElementRes = typename ArrayConverter::result_type::Element; + using ElementSrc = typename ArrayConverter::source_type::Element; + // If no more converters, handle the remaining elements as scalars. 
+ constexpr int total_elements = ArrayConverter::result_type::kElements; + constexpr int remainder = total_elements - Offset; + static_assert(remainder == (total_elements % ParentWidth), "Unexpected remainder."); + + typename ArrayConverter::ScalarConverter scalar_converter; + CUTLASS_PRAGMA_UNROLL for (int i = Offset; i < ArrayConverter::result_type::kElements; ++i) { result[i] = scalar_converter(ElementSrc(source[i])); } } - template + template + CUTLASS_DEVICE + static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { + static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); + static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); + static_assert(cutlass::platform::is_same::value, + "ResultVectorArray must have the same type ArrayConverter::result_type"); + static_assert(cutlass::platform::is_same::value, + "SourceVectorArray must have the same type ArrayConverter::result_type"); + static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); + + static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); + + constexpr int vector_width = ResultVectorArray::kElements; + static_assert(ispow2(vector_width), "Vector width must be a power of 2"); + + using ElementRes = typename ArrayConverter::result_type::Element; + using ElementSrc = typename ArrayConverter::source_type::Element; + + constexpr int vector_bits_res = vector_width * cutlass::sizeof_bits::value; + constexpr int vector_bits_src = vector_width * cutlass::sizeof_bits::value; + + static_assert(vector_bits_res % 8 == 0, "Result vector type must be byte addressed."); + static_assert(vector_bits_src % 8 == 0, "Source vector type must be byte addressed."); + + constexpr int vector_offset = Offset / vector_width; + ResultVectorArray* packed_result_vec = reinterpret_cast(&result) + vector_offset; + SourceVectorArray const* packed_source_vec = reinterpret_cast(&source) + vector_offset; + + // Convert the remaining elements as vectors. + constexpr int total_elements = ArrayConverter::result_type::kElements; + constexpr int groups_of_vec = (total_elements - Offset) / vector_width; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < groups_of_vec; ++i) { + packed_result_vec[i] = ArrayConverter::template packed_convert(packed_source_vec[i]); + } + + constexpr int new_offset = Offset + vector_width * groups_of_vec; + // Recurse to handle other vector converters, or the scalar base case. + convert_helper(result, source); + } + + public: + /* + A method to convert vectors of elements using the packed_convert method of the converter. + + Converters using this class must implement packed convert and support 1 or more vector conversions. 
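+
+    Illustrative call pattern (mirroring the int2/int4 converters later in this file;
+    ElementDst, ElementSrc, and the packed-width aliases stand in for whatever the
+    enclosing converter defines):
+
+      using ConverterType = NumericArrayConverter<ElementDst, ElementSrc, N, Round>;
+      detail::VectorizedConverter::convert<ConverterType,
+                                           result_type_packed_16, source_type_packed_16,
+                                           result_type_packed_8,  source_type_packed_8>(result, source);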
+ */ + template CUTLASS_DEVICE - static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { - static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); - static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); - static_assert(cutlass::platform::is_same::value, - "ResultVectorArray must have the same type ArrayConverter::result_type"); - static_assert(cutlass::platform::is_same::value, - "SourceVectorArray must have the same type ArrayConverter::result_type"); - static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); + static void convert(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { + convert_helper<0, 0, ArrayConverter, ResultVectorArray, SourceVectorArray, OtherVectorArrays...>(result, source); + } + }; +} - static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); + + // Hold output FP8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 2; + + src_reg &= 0x333333333333; // s14s12s10s8s6s4s2s0 + src_reg_shifted &= 0x333333333333; // s15s13s11s9s7s5s3s1 + + // [0, 1, -2, -1] encoded as FP8 + static constexpr uint32_t E4M3_LUT = 0xB8C03800; + + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { + // This uses a look up table to convert packed int2s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. 
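+      // Concretely, for the prmt sequence below: after the masks above, each nibble of
+      // src_reg holds one even-indexed 2-bit element and each nibble of src_reg_shifted
+      // holds one odd-indexed element; the element value (0, 1, -2, or -1 in two's
+      // complement) selects the matching e4m3 byte of E4M3_LUT. Each prmt consumes only
+      // the low four selector nibbles, hence the >>= 16 per pair of output registers;
+      // the final two prmt instructions interleave the even/odd results back into
+      // source element order.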
+ asm volatile( + "{\n" + " .reg .b32 f8_6420, f8_7531;\n" + " prmt.b32 f8_6420, %4, 0, %2;\n" + " prmt.b32 f8_7531, %4, 0, %3;\n" + " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 + " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 + "}\n" + : "=r"(r[ii]), "=r"(r[ii+1]) + : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); + + // Hold output FP8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 2; + + src_reg &= 0x333333333333; // u14u12u10u8u6u4u2u0 + src_reg_shifted &= 0x333333333333; // u15u13u11u9u7u5u3u1 + + // [0, 1, 2, 3] encoded as FP8 + static constexpr uint32_t E4M3_LUT = 0x44403800; + + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { + // This uses a look up table to convert packed uint2s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. 
+ asm volatile( + "{\n" + " .reg .b32 f8_6420, f8_7531;\n" + " prmt.b32 f8_6420, %4, 0, %2;\n" + " prmt.b32 f8_7531, %4, 0, %3;\n" + " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 + " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 + "}\n" + : "=r"(r[ii]), "=r"(r[ii+1]) + : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); + + // Hold output FP8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 2; + + src_reg &= 0x333333333333; // s14s12s10s8s6s4s2s0 + src_reg_shifted &= 0x333333333333; // s15s13s11s9s7s5s3s1 + + // [0, 1, -2, -1] encoded as FP8 + static constexpr uint32_t E4M3_LUT = 0xBCC03C00; + + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { + // This uses a look up table to convert packed int2s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. 
+ asm volatile( + "{\n" + " .reg .b32 f8_6420, f8_7531;\n" + " prmt.b32 f8_6420, %4, 0, %2;\n" + " prmt.b32 f8_7531, %4, 0, %3;\n" + " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 + " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 + "}\n" + : "=r"(r[ii]), "=r"(r[ii+1]) + : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } - constexpr int vector_width = ResultVectorArray::kElements; - static_assert(ispow2(vector_width), "Vector width must be a power of 2"); + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { - using ElementRes = typename ArrayConverter::result_type::Element; - using ElementSrc = typename ArrayConverter::source_type::Element; + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 8 or 16 to use private convert dispatch."); - constexpr int vector_bits_res = vector_width * cutlass::sizeof_bits::value; - constexpr int vector_bits_src = vector_width * cutlass::sizeof_bits::value; + // Hold output FP8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray; + RegArray r; - static_assert(vector_bits_res % 8 == 0, "Result vector type must be byte addressed."); - static_assert(vector_bits_src % 8 == 0, "Source vector type must be byte addressed."); + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 2; - constexpr int vector_offset = Offset / vector_width; - ResultVectorArray* packed_result_vec = reinterpret_cast(&result) + vector_offset; - SourceVectorArray const* packed_source_vec = reinterpret_cast(&source) + vector_offset; + src_reg &= 0x333333333333; // u14u12u10u8u6u4u2u0 + src_reg_shifted &= 0x333333333333; // u15u13u11u9u7u5u3u1 - // Convert the remaining elements as vectors. - constexpr int total_elements = ArrayConverter::result_type::kElements; - constexpr int groups_of_vec = (total_elements - Offset) / vector_width; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < groups_of_vec; ++i) { - packed_result_vec[i] = ArrayConverter::template packed_convert(packed_source_vec[i]); - } + // [0, 1, 2, 3] encoded as FP8 + static constexpr uint32_t E4M3_LUT = 0x42403C00; - constexpr int new_offset = Offset + vector_width * groups_of_vec; - // Recurse to handle other vector converters, or the scalar base case. 
- convert_helper(result, source); + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ii += 2, src_reg >>= 16, src_reg_shifted >>= 16) { + // This uses a look up table to convert packed uint2s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 f8_6420, f8_7531;\n" + " prmt.b32 f8_6420, %4, 0, %2;\n" + " prmt.b32 f8_7531, %4, 0, %3;\n" + " prmt.b32 %0, f8_6420, f8_7531, 0x5140;\n" // 3210 + " prmt.b32 %1, f8_6420, f8_7531, 0x7362;\n" // 7654 + "}\n" + : "=r"(r[ii]), "=r"(r[ii+1]) + : "r"(src_reg), "r"(src_reg_shifted), "n"(E4M3_LUT)); } - public: - /* - A method to convert vectors of elements using the packed_convert method of the converter. + return reinterpret_cast(r); + } - Converters using this class must implement packed convert and support 1 or more vector conversions. - */ - template - CUTLASS_DEVICE - static void convert(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { - convert_helper<0, 0, ArrayConverter, ResultVectorArray, SourceVectorArray, OtherVectorArrays...>(result, source); - } - }; -} + friend class detail::VectorizedConverter; -///////////////////////////////////////////////////////////////////////////////////////////////// +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; /// Partial specialization for Array <= Array template < @@ -3357,9 +4864,9 @@ struct NumericArrayConverter static constexpr uint32_t POS_E4M3s_REG1 = 0x44403800; // [4, 5, 6, 7] encoded as FP8 static constexpr uint32_t POS_E4M3s_REG2 = 0x4E4C4A48; - // [-1, -2, -3, -4] encoded as FP8 + // [-8, -7, -6, -5] encoded as FP8 static constexpr uint32_t NEG_E4M3s_REG1 = 0xCACCCED0; - // [-5, -6, -7, -7] encoded as FP8 + // [-4, -3, -2, -1] encoded as FP8 static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8; @@ -3407,6 +4914,220 @@ struct NumericArrayConverter } }; +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses a lookup table to converts i4 -> e5m2. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); + + // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. 
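+    // The four 32-bit LUT constants below pack the e5m2 encodings of all sixteen
+    // possible int4 values, four bytes per register (POS_* for 0..7, NEG_* for -8..-1).
+    // For example, 0x42403C00 is the byte sequence {0x00, 0x3C, 0x40, 0x42}, i.e.
+    // {0, 1, 2, 3} in e5m2. A prmt.b32 uses the low 3 bits of each nibble to pick a
+    // byte out of a candidate pair, and the nibble's sign bit selects between the
+    // positive and negative candidates.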
+ cutlass::AlignedArray r; + + // View the input as reg + uint32_t reg = to_reg(source); + + // Determines if to get from the signed or unsigned candidates + uint32_t sign = (reg & 0x88888888) >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = (reg & 0x77777777); + + // Signed is OR'd with 0x32103210 to find the correct value in the LUT + const uint32_t final_prmt_base = 0x32103210; + + // [0, 1, 2, 3] encoded as FP8 + static constexpr uint32_t POS_E5M2s_REG1 = 0x42403C00; + // [4, 5, 6, 7] encoded as FP8 + static constexpr uint32_t POS_E5M2s_REG2 = 0x47464544; + // [-8, -7, -6, -5] encoded as FP8 + static constexpr uint32_t NEG_E5M2s_REG1 = 0xC5C6C7C8; + // [-4, -3, -2, -1] encoded as FP8 + static constexpr uint32_t NEG_E5M2s_REG2 = 0xBCC0C2C4; + + + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { + uint32_t final_prmt_idx = final_prmt_base | sign; + + // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 pos_f8s, neg_f8s;\n" + " prmt.b32 pos_f8s, %1, %2, %5;\n" + " prmt.b32 neg_f8s, %3, %4, %5;\n" + " prmt.b32 %0, pos_f8s, neg_f8s, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(POS_E5M2s_REG1), "n"(POS_E5M2s_REG2), "n"(NEG_E5M2s_REG1), "n"(NEG_E5M2s_REG2), + "r"(lut_idx), "r"(final_prmt_idx)); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses a lookup table to converts u4 -> e4m3. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); + + // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. 
+ cutlass::AlignedArray r; + + // View the input as reg + uint32_t reg = to_reg(source); + + // Determines if to get from the [0-7] or [8-15] candidates + uint32_t sign = (reg & 0x88888888) >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = (reg & 0x77777777); + + // Signed is OR'd with 0x32103210 to find the correct value in the LUT + const uint32_t final_prmt_base = 0x32103210; + + // [0, 1, 2, 3] encoded as FP8 + static constexpr uint32_t E4M3s_REG1 = 0x44403800; + // [4, 5, 6, 7] encoded as FP8 + static constexpr uint32_t E4M3s_REG2 = 0x4E4C4A48; + // [8, 9, 10, 11] encoded as FP8 + static constexpr uint32_t E4M3s_REG3 = 0x53525150; + // [12, 13, 14, 15] encoded as FP8 + static constexpr uint32_t E4M3s_REG4 = 0x57565554; + + + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { + uint32_t final_prmt_idx = final_prmt_base | sign; + + // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 f8s_1, f8s_2;\n" + " prmt.b32 f8s_1, %1, %2, %5;\n" + " prmt.b32 f8s_2, %3, %4, %5;\n" + " prmt.b32 %0, f8s_1, f8s_2, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(E4M3s_REG1), "n"(E4M3s_REG2), "n"(E4M3s_REG3), "n"(E4M3s_REG4), + "r"(lut_idx), "r"(final_prmt_idx)); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + /// Partial specialization for Array <= Array template struct NumericArrayConverter { diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index e5fa5f9cbd..b0c616a75a 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -82,5 +82,7 @@ struct get_unpacked_element_type { #include "cutlass/tfloat32.h" #include "cutlass/float8.h" #include "cutlass/uint128.h" +#include "cutlass/exmy_base.h" +#include "cutlass/float_subbyte.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/pipeline.hpp b/include/cutlass/pipeline/pipeline.hpp index 040ecee3c5..e9cf66a794 100644 --- a/include/cutlass/pipeline/pipeline.hpp +++ b/include/cutlass/pipeline/pipeline.hpp @@ -33,4 +33,6 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/pipeline/sm100_pipeline.hpp" + //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm100_pipeline.hpp b/include/cutlass/pipeline/sm100_pipeline.hpp new file mode 100644 index 0000000000..e5ac47a844 --- /dev/null +++ b/include/cutlass/pipeline/sm100_pipeline.hpp @@ -0,0 +1,918 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +// + +// + +#include "cute/numeric/integral_constant.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "sm90_pipeline.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +using namespace cute; + +enum class McastDirection { + kRow, + kCol, + kRowCol +}; +namespace detail { + +template +CUTLASS_DEVICE +uint16_t calculate_multicast_mask(ClusterShape cluster_shape, AtomThrShape_MNK atom_thr_shape, dim3 block_id_in_cluster) { + auto is_participant = [&](auto x, auto y) { + if constexpr (McastDir == McastDirection::kRowCol) { + return (x/size<0>(atom_thr_shape) == block_id_in_cluster.x/size<0>(atom_thr_shape) || // is same MMA cluster col + y/size<1>(atom_thr_shape) == block_id_in_cluster.y/size<1>(atom_thr_shape)); // is same MMA cluster row + } + else if constexpr (McastDir == McastDirection::kRow) { + return (x/size<0>(atom_thr_shape) == block_id_in_cluster.x/size<0>(atom_thr_shape)); // is same MMA cluster row + } + else { // (McastDir == McastDirection::kCol) + return (y/size<1>(atom_thr_shape) == block_id_in_cluster.y/size<1>(atom_thr_shape)); // is same MMA cluster col + } + }; + + uint16_t block_id_mask = 0; + auto cluster_layout = make_layout(cluster_shape); + // When MMA_2x1SM instructions are used, the definition of "same row" changes. + // With MMA_2x1SM, we need to send the notification for MMA completion to all + // 2x1 threadblocks of the cluster. Below is a 4x4 example where R are the threadblocks + // that receives the release for A/B buffers that threadblock (0,0) uses. 
+ // Row&Col Row Col + // RRRR RRRR Cxxx + // RRRR RRRR Cxxx + // Rxxx xxxx Cxxx + // Rxxx xxxx Cxxx + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x(cluster_shape); x++) { + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y(cluster_shape); y++) { + if (is_participant(x,y)) { + block_id_mask |= (1 << cluster_layout(x,y, Int<0>{})); + } + } + } + return block_id_mask; +} + +template +CUTLASS_DEVICE +uint16_t calculate_umma_peer_mask(ClusterShape cluster_shape, AtomThrShape_MNK atom_thr_shape, dim3 block_id_in_cluster) { + uint16_t tmem_sync_mask = 0; + auto cluster_layout = make_layout(cluster_shape); + int block_id_in_cluster_x = (block_id_in_cluster.x / size<0>(AtomThrShape_MNK{})) * size<0>(AtomThrShape_MNK{}) ; + int block_id_in_cluster_y = (block_id_in_cluster.y / size<1>(AtomThrShape_MNK{})) * size<1>(AtomThrShape_MNK{}) ; + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < size<0>(AtomThrShape_MNK{}); x++) { + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < size<1>(AtomThrShape_MNK{}); y++) { + tmem_sync_mask |= (1 << cluster_layout(block_id_in_cluster_x + x, block_id_in_cluster_y + y, Int<0>{})); + } + } + + return tmem_sync_mask; +} +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA (producer) Async Pipeline class for Blackwell UMMA +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template > +class PipelineUmmaAsync { +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; +private: + using Impl = PipelineAsync; +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate producer mask + if (params_.role == ThreadCategory::Producer) { + // The leader threadblock executing the MMA_2x1SM instruction will signal its peer + // threadblock when it is done with MMA operations. tmem_sync_mask encodes the + // position of peer SMs in the cluster + tmem_sync_mask_ = detail::calculate_umma_peer_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + // Constructor by default initializes barriers and calculates masks. + // These operations can be explicity deferred by specifying InitBarriers and InitMasks. + // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. 
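+  //
+  // Illustrative sketch of the deferred path (comment only; `N`, `storage`, `params`,
+  // and `cluster_shape` are placeholder names):
+  //
+  //   PipelineUmmaAsync<N> pipe(storage, params, cluster_shape,
+  //                             cute::false_type{}, cute::false_type{});
+  //   PipelineUmmaAsync<N>::init_barriers(storage, params);
+  //   pipe.init_masks(cluster_shape);
+  //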
+ template + CUTLASS_DEVICE + PipelineUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, InitBarriers{}) + , params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. + + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return impl_.producer_try_acquire(state, skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index()); + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. 
+ CUTLASS_DEVICE + void producer_tail(PipelineState state) { + impl_.producer_tail(state); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state.index()); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return impl_.consumer_try_wait(state, skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + detail::pipeline_check_is_consumer(params_.role); + if constexpr (is_2sm_mma) { + consumer_release_2x1SM(state.index()); + } else { + impl_.consumer_release(state); + } + } + +private: + Impl impl_; + Params params_; + FullBarrier* full_barrier_ptr_ = nullptr; + EmptyBarrier* empty_barrier_ptr_ = nullptr; + uint16_t tmem_sync_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + CUTLASS_DEVICE + void producer_commit(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&full_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, tmem_sync_mask_); + } + else { + cutlass::arch::umma_arrive(smem_ptr); + } + } + + CUTLASS_DEVICE + void consumer_release_2x1SM(uint32_t stage) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + cutlass::arch::umma_arrive_2x1SM_sm0(smem_ptr); + static_assert(is_2sm_mma, "ERROR : AtomThrShape_MNK does not correspond to a 2SM MMMA"); + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA (consumer) Async Pipeline classes for Blackwell UMMA +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Producer-consumer pipeline implementation +// for UMMA producer. In this case, UMMA barrier arrives are used +// by producer_commit. Use case, accumulator generation as +// the result of MMA instructions. 
+template < + int Stages_, + class ClusterShape = Shape, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaUmmaAsync { +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; +private: + using Impl = PipelineTmaAsync; +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? + cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas + cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + else { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + // Constructor by default initializes barriers and calculates masks. + // These operations can be explicity deferred by specifying InitBarriers and InitMasks. + // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. 
+ template + CUTLASS_DEVICE + PipelineTmaUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + // !!!!!! I DONT LIKE THIS MCAST BASED CONSTRUCTOR SPECIALIZATION. THIS VARIABLE NEVER CHANGES AT RUNTIME. + template + CUTLASS_DEVICE + PipelineTmaUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + dim3 block_id = block_id_in_cluster(); + + int warp_idx = canonical_warp_idx_sync(); + auto atom_thr_shape = AtomThrShape_MNK{}; + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return impl_.producer_try_acquire(state, skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + // NOP for TMA based mainloop + CUTLASS_DEVICE + void producer_commit(PipelineState state, uint32_t bytes) { + impl_.producer_commit(state, bytes); + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. 
+ CUTLASS_DEVICE + void producer_tail(PipelineState state) { + impl_.producer_tail(state); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return impl_.consumer_try_wait(state, skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index(), false); + } + +private: + Impl impl_; + Params params_; + EmptyBarrier *empty_barrier_ptr_; + FullBarrier *full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } + else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } + else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; + +// Producer-consumer pipeline implementation +// for UMMA consumer. In this case, UMMA barrier arrives are +// used by consumer_release. +template > +class PipelineUmmaConsumerAsync { +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; +private: + using Impl = PipelineAsync; +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + template + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + // The leader threadblock executing the MMA_2x1SM instruction will signal its peer + // threadblock when it is done with MMA operations. tmem_sync_mask encodes the + // position of peer SMs in the cluster + tmem_sync_mask_ = detail::calculate_umma_peer_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + // Constructor by default initializes barriers and calculates masks. + // These operations can be explicity deferred by specifying InitBarriers and InitMasks. + // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. 
+ template + CUTLASS_DEVICE + PipelineUmmaConsumerAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, InitBarriers{}) + , params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + //////////////////// + // Producer APIs + //////////////////// + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return impl_.producer_try_acquire(state, skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + template + CUTLASS_DEVICE + void producer_commit(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { + cute::forward(user_defined_arrive_op)(producer_get_barrier(state)); + producer_commit(state); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + if constexpr (is_2sm_mma) { + producer_commit_2x1SM(state.index()); + } else { + impl_.producer_commit(state); + } + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { + impl_.producer_tail(state); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state.index()); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return impl_.consumer_try_wait(state, skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + if (barrier_token == BarrierStatus::WaitAgain) { + impl_.consumer_wait(state); + } + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +private: + Impl impl_; + Params params_; + FullBarrier* full_barrier_ptr_ = nullptr; + EmptyBarrier* empty_barrier_ptr_ = nullptr; + uint16_t tmem_sync_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + CUTLASS_DEVICE + void producer_commit_2x1SM(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&full_barrier_ptr_[stage]); + cutlass::arch::umma_arrive_2x1SM_sm0(smem_ptr); + static_assert(is_2sm_mma, "ERROR : AtomThrShape_MNK does not correspond to a 2SM MMMA"); + } + + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, tmem_sync_mask_); + } + else { + cutlass::arch::umma_arrive(smem_ptr); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// CLC Async Pipeline class for Blackwell UMMA +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace PipelineDetail { + +template +using PipelineCLCFetchAsyncPipelineState = cutlass::PipelineState; + +template +struct PipelineCLCFetchAsyncSharedStorage { + using FullBarrier = 
cutlass::arch::ClusterTransactionBarrier; + using EmptyBarrier = cutlass::arch::ClusterBarrier; + + FullBarrier full_barrier_[static_cast(Stages_)]; + EmptyBarrier empty_barrier_[static_cast(Stages_)]; +}; + +} // namespace PipelineDetail + +template > +class PipelineCLCFetchAsync { + +public: + static constexpr uint32_t Stages = Stages_; + using PipelineState = PipelineDetail::PipelineCLCFetchAsyncPipelineState; + using SharedStorage = PipelineDetail::PipelineCLCFetchAsyncSharedStorage; + using FullBarrier = typename SharedStorage::FullBarrier; + using EmptyBarrier = typename SharedStorage::EmptyBarrier; + + enum class ThreadCategory { + NonParticipant, + Producer, + Consumer, + ProducerConsumer + }; + + struct Params { + uint32_t transaction_bytes = 0; + ThreadCategory role = ThreadCategory::NonParticipant; + uint32_t is_leader = 0; + uint32_t num_consumers = 0; + uint32_t producer_blockid = 0; + uint32_t producer_arv_count = 0; + uint32_t consumer_arv_count = 0; + int initializing_warp = 0; + }; + + // Constructor + CUTLASS_DEVICE + PipelineCLCFetchAsync(SharedStorage& storage, Params const& params) : + params_(params), + full_barrier_ptr_(&storage.full_barrier_[0]), + empty_barrier_ptr_(&storage.empty_barrier_[0]) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); + } + cutlass::arch::fence_barrier_init(); + + cluster_size_ = []() { auto cs = cute::cluster_shape(); return cs.x * cs.y; }(); + } + + // Constructor + CUTLASS_DEVICE + PipelineCLCFetchAsync(SharedStorage& storage, Params const& params, ClusterShape cluster_shape) + : params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); + } + cutlass::arch::fence_barrier_init(); + + cluster_size_ = cute::size<0>(cluster_shape) + * cute::size<1>(cluster_shape) + * cute::size<2>(cluster_shape); + } + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. 
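+  //
+  // Illustrative usage sketch (comment only; `pipeline` and `state` are hypothetical
+  // names). A producer would typically do:
+  //
+  //   auto token = pipeline.producer_try_acquire(state);  // opportunistic wait
+  //   pipeline.producer_acquire(state, token);            // blocks until EMPTY flips
+  //   /* produce into stage state.index() */
+  //   pipeline.producer_commit(state);
+  //   ++state;
+  //
+  // and a consumer the mirrored consumer_try_wait / consumer_wait / consumer_release
+  // sequence.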
+ CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return producer_try_acquire(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire(state.index(), state.phase(), barrier_token); + } + + // Manual completion of transaction count + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index(), state.phase()); + } + + // Prevents early exit of producer blocks in Cluster. + // Does NOT reset transaction bytes. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { + detail::pipeline_check_is_producer(params_.role); + for (int count = 0; count < Stages; ++count) { + bool done = empty_barrier_ptr_[state.index()].test_wait(state.phase()); + if (!done) { + empty_barrier_ptr_[state.index()].wait(state.phase()); + } + ++state; + } + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return consumer_try_wait(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + consumer_wait(state.index(), state.phase(), barrier_token); + } + + // Consumer signalling Producer of completion + // Notifies the producer block in the Cluster + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + + CUTLASS_DEVICE + uint32_t producer_get_barrier(PipelineState state) { + return cute::cast_smem_ptr_to_uint(reinterpret_cast(&full_barrier_ptr_[state.index()])); + } + +private: + FullBarrier *full_barrier_ptr_ = nullptr; + EmptyBarrier *empty_barrier_ptr_ = nullptr; + Params params_; + int lane_idx_ = canonical_lane_idx(); + int cluster_size_; + + CUTLASS_DEVICE + ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + bool barrier_stat = empty_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_stat)}; + } + + CUTLASS_DEVICE + void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); + // 1. Wait for empty barrier to be ready + // 2. 
Set the transaction bytes set to occur on the Full barrier for all blocks + if (barrier_token == BarrierStatus::WaitAgain) { + empty_barrier_ptr_[stage].wait(phase); + } + + full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes, lane_idx_, uint32_t(lane_idx_ < cluster_size_)); + } + + CUTLASS_DEVICE + void producer_commit(uint32_t stage, uint32_t phase) { + int cluster_size_ = []() { auto cs = cute::cluster_shape(); return cs.x * cs.y; }(); + full_barrier_ptr_[stage].complete_transaction(lane_idx_, params_.transaction_bytes, uint32_t(lane_idx_ < cluster_size_)); + } + + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + bool barrier_stat = full_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_stat)}; + } + + // Wait for producer to commit transactions + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); + if (barrier_token == BarrierStatus::WaitAgain) { + full_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void consumer_release(uint32_t stage) { + detail::pipeline_check_is_consumer(params_.role); + empty_barrier_ptr_[stage].arrive(params_.producer_blockid); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Empty Pipeline class +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +class PipelineEmpty { +public: + static constexpr uint32_t Stages = 0; + using PipelineState = cutlass::PipelineState<0>; + struct Params {}; + struct SharedStorage {}; + + // Constructor + CUTLASS_DEVICE + PipelineEmpty(SharedStorage&& storage, Params const& params) {} + + // Constructor with throwaway ClusterShape + template > + CUTLASS_DEVICE + PipelineEmpty(SharedStorage&& storage, Params const& params, ClusterShape) {} + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 58f49c36f7..6b766fc246 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -295,6 +295,7 @@ class PipelineTmaAsync { uint32_t is_leader = 0; uint32_t num_consumers = 0; // Number of consumer threads uint32_t num_producers = 1; // Number of producer threads + int initializing_warp = 0; }; template @@ -304,6 +305,7 @@ class PipelineTmaAsync { init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { int warp_idx = canonical_warp_idx_sync(); bool is_initializing_warp = (warp_idx == 0); + is_initializing_warp = (warp_idx == params.initializing_warp); if (is_initializing_warp) { // Barrier FULL and EMPTY init uint32_t const producer_arv_cnt = params.num_producers; @@ -750,6 +752,7 @@ class PipelineTransactionAsync { uint32_t producer_arv_count = 1; uint32_t 
consumer_arv_count = 1; uint32_t dst_blockid = cute::block_rank_in_cluster(); + int initializing_warp = 0; }; static @@ -760,6 +763,8 @@ class PipelineTransactionAsync { EmptyBarrier *empty_barrier_ptr = storage.empty_barrier_.data(); int warp_idx = canonical_warp_idx_sync(); bool is_initializing_warp = (warp_idx == 0); + is_initializing_warp = (warp_idx == params.initializing_warp); + if (is_initializing_warp) { // Barrier FULL and EMPTY init cutlass::arch::detail::initialize_barrier_array_pair_aligned( @@ -989,6 +994,7 @@ class PipelineAsync { uint32_t producer_arv_count = 1; uint32_t consumer_arv_count = 1; uint32_t dst_blockid = cute::block_rank_in_cluster(); + int initializing_warp = 0; }; static @@ -997,6 +1003,7 @@ class PipelineAsync { init_barriers(SharedStorage& storage, Params params) { int warp_idx = canonical_warp_idx_sync(); bool is_initializing_warp = (warp_idx == 0); + is_initializing_warp = (warp_idx == params.initializing_warp); if (is_initializing_warp) { // Barrier FULL and EMPTY init cutlass::arch::detail::initialize_barrier_array_pair_aligned( @@ -1222,6 +1229,7 @@ class OrderedSequenceBarrier { struct Params { uint32_t group_id; uint32_t group_size; + int initializing_warp = 0; }; private: @@ -1247,6 +1255,19 @@ class OrderedSequenceBarrier { barrier_ptr_(&storage.barrier_[0][0]), // Group 0 - starts with an opposite phase stage_({0, params.group_id == 0, 0}) { + +#if (__CUDA_ARCH__ >= 1000) + int warp_idx = canonical_warp_idx_sync(); + + // Barrier FULL, EMPTY init + if (warp_idx == params.initializing_warp) { + int arv_cnt = params.group_size; + constexpr int Stages = Depth * Length; + cutlass::arch::detail::initialize_barrier_array_aligned( + barrier_ptr_, arv_cnt); + } +#else + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); @@ -1259,6 +1280,7 @@ class OrderedSequenceBarrier { } } } +#endif cutlass::arch::fence_barrier_init(); } diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 779c155227..65b77904d1 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -270,6 +270,36 @@ bool relatively_equal(complex a, complex b, complex epsilon, complex +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e2m3_t a, float_e2m3_t b, float_e2m3_t epsilon, float_e2m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e3m2_t a, float_e3m2_t b, float_e3m2_t epsilon, float_e3m2_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e2m1_t a, float_e2m1_t b, float_e2m1_t epsilon, float_e2m1_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_ue8m0_t a, float_ue8m0_t b, float_ue8m0_t epsilon, float_ue8m0_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_ue4m3_t a, float_ue4m3_t b, float_ue4m3_t epsilon, float_ue4m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 984d39d186..1e2b5de94d 100644 --- a/include/cutlass/version.h 
+++ b/include/cutlass/version.h
@@ -35,7 +35,7 @@
 #include
 #define CUTLASS_MAJOR 3
-#define CUTLASS_MINOR 7
+#define CUTLASS_MINOR 8
 #define CUTLASS_PATCH 0
 #ifdef CUTLASS_VERSIONS_GENERATED
diff --git a/media/docs/blackwell_functionality.md b/media/docs/blackwell_functionality.md
new file mode 100644
index 0000000000..a7c6169f4f
--- /dev/null
+++ b/media/docs/blackwell_functionality.md
@@ -0,0 +1,584 @@
+# Blackwell SM100 GEMMs
+
+[**TLDR; jump to block scaled GEMM example**](#detailed_blockscale_example)
+
+Blackwell SM100 introduces `tcgen05.mma` instructions. `tcgen05.mma` instructions support all legacy types (`tfloat32_t`, `half_t`, `bfloat16_t`, `int8_t`, `uint8_t`) and
+the new 4, 6, and 8-bit floating point data types, with and without scale factors.
+This document explains the new `tcgen05.mma` instructions supported by CUTLASS and how one can leverage CUTLASS to create
+efficient SM100 GEMM kernels targeting these new MMA instructions.
+
+Blackwell SM100 has 7 new `tcgen05.mma` instructions. These instructions are 2x to 4x faster than the Hopper architecture's WGMMA instructions.
+
+| PTX Instruction | Throughput | Notes |
+|----------------------------------------------------------------------------------|----------------------------|-------|
+|tcgen05.mma.cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts |
+|tcgen05.mma.cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts |
+|tcgen05.mma.cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts |
+|tcgen05.mma.cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts |
+|tcgen05.mma.cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts |
+|tcgen05.mma.cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts |
+|tcgen05.mma.cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts |
+
+For more detailed information see [`tcgen05.mma` PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensorcore-5th-generation-family-instructions).
+
+## New in Blackwell SM100
+
+### Block Scaled GEMMs
+
+Instructions with `kind` modifiers `mxf8f6f4`, `mxf4`, and `mxf4nvf4` perform matrix multiplication operations with scale
+factors of the form $D = C + (A \times SFA) \times (B \times SFB)$. Scale factors are applied along the GEMM-K dimension such that
+every 16 or 32 elements of the $A$ and $B$ matrices in the K dimension have an associated scale factor. For example, an $M \times K$
+$A$ matrix has an associated $M \times \lceil K/32 \rceil$ SFA matrix, and an $N \times K$ $B$ matrix has an associated
+$N \times \lceil K/32 \rceil$ SFB matrix. For block scaled GEMMs, an entry of the output $D$ matrix is, in index notation,
+$D_{ij} = C_{ij} + \sum_{k} (A_{i,k} \times SFA_{i,k/SV}) \times (B_{j,k} \times SFB_{j,k/SV})$, where $SV$ is the scale factor vector size (16 or 32).
+Further details can be found in
+[PTX documentation on block scaling](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-block-scaling).
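+
+To make the indexing concrete, the following is a minimal scalar reference sketch of a single output entry. It is illustrative only: `SV`, `a_row`, `sfa_row`, `b_col`, and `sfb_col` are placeholder names, and the scale factors are held as `float` for clarity rather than in their `ue8m0`/`ue4m3` storage formats.
+
+```c++
+// D[i][j] = C[i][j] + sum_k (A[i][k] * SFA[i][k/SV]) * (B[j][k] * SFB[j][k/SV])
+// SV is the scale factor vector size (16 or 32).
+template <int SV>
+float block_scaled_dot(int K,
+                       float const* a_row, float const* sfa_row,   // A[i][.], SFA[i][.]
+                       float const* b_col, float const* sfb_col,   // B[j][.], SFB[j][.]
+                       float c) {
+  float d = c;
+  for (int k = 0; k < K; ++k) {
+    d += (a_row[k] * sfa_row[k / SV]) * (b_col[k] * sfb_col[k / SV]);
+  }
+  return d;
+}
+```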
+
+### Blackwell Narrow Precision Data Types
+
+Narrow-precision `tcgen05.mma` instructions can operate on several 4, 6, and 8-bit data types. Blackwell MMAs can operate
+on the 8-bit floating point types listed below, of which only two (`float_ue8m0_t` and `float_ue4m3_t`) can be used as scale factor data types.
+There are two 6-bit floating point types and one 4-bit floating point data type.
+See [PTX documentation for narrow precision data types](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats) for details.
+
+**Blackwell Narrow Precision Data Types**
+| Data Type | Exponent Bits | Mantissa Bits | Signed | Bit Size |
+|-------------------|---------------|---------------|--------|----------|
+| float_e4m3_t |4 |3 | Yes | 8 |
+| float_e5m2_t |5 |2 | Yes | 8 |
+| float_e2m3_t |2 |3 | Yes | 6 |
+| float_e3m2_t |3 |2 | Yes | 6 |
+| float_e2m1_t |2 |1 | Yes | 4 |
+| float_ue8m0_t[^1] |8 |0 | No | 8 |
+| float_ue4m3_t[^1] |4 |3 | No | 8 |
+
+[^1]: Only valid as scale factor data types.
+
+Block scaled MMAs use `mx` and `nv` types, which pair one of the `float8_t`, `float6_t`, or `float4_t` data types with one of the two scale factor data types and a predetermined scale factor vector size. `mx` types follow the OCP specification (see [OCP Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)). The following types provided by CUTLASS can be used as inputs to collective builders to generate block scaled kernels:
+
+**Blackwell Block Scaled Narrow Precision Data Types**
+| Mx/Nv Data Type |Scale Factor Type | SF Vector Size | OCP Compliant |
+|----------------------------|------------------|----------------|---------------|
+| mx_float8_t |float_ue8m0_t |32 | Yes |
+| mx_float6_t |float_ue8m0_t |32 | Yes |
+| mx_float4_t |float_ue8m0_t |32 | Yes |
+| nv_float4_t |float_ue4m3_t |16 | No |
+
+## Layouts and Tensor Alignment Requirements to Target `tcgen05.mma` Instructions
+
+The tables below list valid data type and A/B layout combinations. Note that alignment is reported as a number of elements. A and B matrix layouts are
+represented with T and N: T represents a row-major layout and N represents a column-major layout. For instance, TN means a
+row-major A matrix with a column-major B matrix.
+
+For legacy types (`tf32`, `f16`, `bf16`, `i8` and `u8`), the alignment requirements for the A and B matrices are the same as on Hopper.
+All four layouts (TN, NN, NT, TT) are supported for all legacy data types.
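+
+For the legacy types, the required alignments (listed in Table 1 below) simply correspond to 16-byte (128-bit) vectorized accesses. The snippet below is an illustrative check, not a CUTLASS-mandated API; it assumes `cutlass/numeric_types.h` provides `cutlass::sizeof_bits`, as it does in recent CUTLASS releases.
+
+```c++
+#include <cstdint>
+#include "cutlass/numeric_types.h"
+
+// Minimum alignment in elements for a 128-bit access.
+template <class Element>
+constexpr int min_alignment_elements() {
+  return 128 / cutlass::sizeof_bits<Element>::value;
+}
+
+// Reproduces the A/B alignments listed in Table 1 below.
+static_assert(min_alignment_elements<cutlass::tfloat32_t>() == 4,  "tf32 -> 4 elements");
+static_assert(min_alignment_elements<cutlass::half_t>()     == 8,  "f16  -> 8 elements");
+static_assert(min_alignment_elements<int8_t>()              == 16, "i8   -> 16 elements");
+```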
+
+**Table 1: Valid Data Type, Alignment, and Layout Combinations For MMAs with Legacy Types**
+| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
+|-------------------------------|------------|------------|----------------|-------------|-------------|-------------------------|-----------|
+|1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
+|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
+|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)|
+|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
+|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)|
+
+For narrow-precision MMAs, not all A/B type and A/B layout combinations are supported by every `tcgen05.mma` instruction.
+Furthermore, tensor copy instructions for subbyte types impose additional alignment requirements when loading narrow-precision
+tensors from global memory to shared memory
+(see [PTX doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-tensor-copy-restrictions) for details).
+
+The tables below list the valid layout and alignment values for each A and B data type combination and their target `tcgen05.mma`
+instructions supported by CUTLASS.
+
+**Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling**
+| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
+|-------------------------------|----------|----------|----------------|-------------|-------------|-------------------------|-----------|
+|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | +|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | +|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | +|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | +|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](../../test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)| + + +**Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs** +| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test| +|-------------------------|-------------|-------------|----------------|-------------|-------------|-------------------------|------| +|[1](#bs_rows_1) | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)| +|[2](#bs_rows_2) | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)| +|[3](#bs_rows_3) | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)| +|[4](#bs_rows_4_5_7_8_10) | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)| +|[5](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)| +|[6](#bs_rows_6_9_11) | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)| +|[7](#bs_rows_4_5_7_8_10) | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)| +|[8](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)| +|[9](#bs_rows_6_9_11) | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)| +|[10](#bs_rows_4_5_7_8_10)| mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)| +|[11](#bs_rows_6_9_11) | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu)
[NT unit tests](../../test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)|
+
+## MMA tile shapes supported
+
+The alignment restrictions also limit the options for MMA tile shapes. The tables below list the valid `MmaTileShape`,
+layout, and dispatch policy combinations for each row of [Table 1](#legacy_gemm_table), [Table 2](#non_bs_gemm_table), and [Table 3](#bs_gemm_table).
+
+**Table 4: Valid Tile Shapes and Dispatch Policies for legacy types (All rows of Table 1)**
+| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
+|--------|------------------|----|----|----|----|------------------------------------|
+| 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+
+**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6 of Table 2)**
+
+| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
+|--------|----------------|----|----|----|----|------------------------------------|
+| 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` |
+| 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` |
+| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` |
+
+**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 
5,8 of Table 2)** + +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|----------------|----|----|----|----|------------------------------------| +| 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + +**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7 of Table 2)** + +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|----------------|----|----|----|----|------------------------------------| +| 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + +**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9 of Table 2)** + +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|----------------|----|----|----|----|------------------------------------| +| 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x128x128 | Y | Y | Y | Y | 
`KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + + +**Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 of Table 3)** +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|---------------|----|----|----|----|----------------------------------------| +| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | +| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | +| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | +| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | +| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | +| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | + +**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 of Table 3)** +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|---------------|----|----|----|----|----------------------------------------| +| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | +| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | +| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | +| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | +| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | +| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | + +**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 of Table 3)** +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|---------------|----|----|----|----|--------------------------------------------| +| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` | +| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` | +| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` | + +**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10 of Table 3)** +| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|--------|---------------|----|----|----|----|--------------------------------------------| +| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| 1SM | 128x256x128 | Y | Y | Y | 
Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+
+**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11 of Table 3)**
+| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
+|--------|---------------|----|----|----|----|--------------------------------------------|
+| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+
+## Epilogue config supported
+
+**Table 14: Epilogue Dispatch Policy**
+| 1/2 SM | Epilogue Dispatch Policy                 |
+|--------|------------------------------------------|
+| 1SM    | cutlass::epilogue::TmaWarpSpecialized1Sm |
+| 2SM    | cutlass::epilogue::TmaWarpSpecialized2Sm |
+
+**Table 15: Epilogue PerSmTileShape_MNK**
+| 1/2 SM | MMA Tile Shape           | PerSmTileShape_MNK      |
+|--------|--------------------------|-------------------------|
+| 1SM | 64x64xMMA_TileShape_K | 64x64xMMA_TileShape_K |
+| 1SM | 64x128xMMA_TileShape_K | 64x128xMMA_TileShape_K |
+| 1SM | 64x192xMMA_TileShape_K | 64x192xMMA_TileShape_K |
+| 1SM | 64x256xMMA_TileShape_K | 64x256xMMA_TileShape_K |
+| 1SM | 128x64xMMA_TileShape_K | 128x64xMMA_TileShape_K |
+| 1SM | 128x128xMMA_TileShape_K | 128x128xMMA_TileShape_K |
+| 1SM | 128x192xMMA_TileShape_K | 128x192xMMA_TileShape_K |
+| 1SM | 128x256xMMA_TileShape_K | 128x256xMMA_TileShape_K |
+| 2SM | 128x64xMMA_TileShape_K | 64x64xMMA_TileShape_K |
+| 2SM | 128x128xMMA_TileShape_K | 64x128xMMA_TileShape_K |
+| 2SM | 128x192xMMA_TileShape_K | 64x192xMMA_TileShape_K |
+| 2SM | 128x256xMMA_TileShape_K | 64x256xMMA_TileShape_K |
+| 2SM | 256x64xMMA_TileShape_K | 128x64xMMA_TileShape_K |
+| 2SM | 256x128xMMA_TileShape_K | 128x128xMMA_TileShape_K |
+| 2SM | 256x192xMMA_TileShape_K | 128x192xMMA_TileShape_K |
+| 2SM | 256x256xMMA_TileShape_K | 128x256xMMA_TileShape_K |
+
+MMA_TileShape_K is generally 4 * MMA-Instruction-K; the exact value depends on the configuration chosen in the MMA tile shapes supported section above.
+
+### Auto Kernel Dispatch Policies
+
+In addition to the direct dispatch policies listed above, the user can also use auto policies for both non-block scaled and
+block scaled narrow-precision GEMMs.
+
+CUTLASS will do its best to find the most efficient kernel for the given parameters; however, the preferred method for building
+these kernels is to use the direct kernel dispatch policies shown in the tables above.
+
+* `cutlass::gemm::collective::KernelScheduleAuto`: For a given MMA tile size, data type, and layout combination, choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) and the 1 or 2 SM `tcgen05.mma` automatically.
+* `KernelTmaWarpSpecialized1SmBlockScaledSm100`: Use the 1 SM `tcgen05.mma` instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
+* `KernelTmaWarpSpecialized2SmBlockScaledSm100`: Use the 2 SM `tcgen05.mma` instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
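+
+To make the difference concrete, the following sketch (illustrative only, not a CUTLASS sample; the header paths and the `cutlass::gemm::` namespace for the block scaled auto policy are my assumptions) shows the three kinds of mainloop schedule one can hand to the collective builder for a 2SM block scaled kernel. Only this one template argument changes between a fully pinned-down kernel and an auto-dispatched one:
+
+```cpp
+#include "cutlass/gemm/dispatch_policy.hpp"                 // direct kernel schedule policies
+#include "cutlass/gemm/collective/collective_builder.hpp"   // KernelScheduleAuto
+
+// Direct policy: pins both the 2SM mode and the mxf8f6f4 instruction kind (as in the tables above).
+using DirectPolicy   = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;
+
+// Semi-auto policy: pins the 2SM mode, lets CUTLASS choose the block scaled instruction kind.
+// (Namespace assumed to match the other kernel schedule policies.)
+using SemiAutoPolicy = cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100;
+
+// Fully auto policy: CUTLASS chooses both the instruction kind and the 1SM/2SM mode.
+using AutoPolicy     = cutlass::gemm::collective::KernelScheduleAuto;
+```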
+ +Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueScheduleAuto`. + +## Building a Block Scaled Kernel + +For non-blockscaled dense GEMM refer to [quick start page](quickstart.md#instantiating-a-blackwell-gemm-kernel). An example dense GEMM can be found: +1. [Blackwell FP16 GEMM example](../../examples/70_blackwell_gemm/). + +Narrow precision and block scaled narrow precision kernels can be built using CUTLASS 3.x collective builder interface +(as described in [CUTLASS 3.0 GEMM API](gemm_api_3x.md#cutlass-30-gemm-api)). However, special attention needs to be given to +A and B matrix layouts, alignment requirements, and dispatch policies to obtain a functionally correct and performant kernel +which are listed above. + +Several examples of block scaled kernels can be found in [examples/72_blackwell_narrow_precision_gemm](../../examples/72_blackwell_narrow_precision_gemm/) directory: +1. [NVF4 Gemm with block scaling](../../examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu) +2. [NVF4 Gemm with block scaling and NVF4 output matrix](../../examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu) +3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](../../examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu) + +Collective builder interface expects the same arguments as any other CUTLASS 3.x kernels as described +[here](gemm_api_3x.md#collective-builder-for-collectivemmas) with a small difference for Collective MMA builder interface. +As in all Blackwell kernels, the `TileShape_MNK` argument expects the `MmaTileShape_MNK` which is the tile shape needed +by 1 or 2 SM `tcgen05.mma` instructions. + +Let's consider building a block scaled GEMM where the A matrix is of type `mx_float4_t` and column-major (N), and the +B matrix is of type `mx_float4_t` and row-major (T). We first need to describe the A and B tensors, and find the +instruction that can support the selected A and B type and layout pair. Then, we will choose the performance parameters. + +The skeleton C++ code is shown below: + +```cpp + /////////////////////////////////////////////////////////// + // Mainloop Builder Setup + /////////////////////////////////////////////////////////// + + /////////////////////////////////////////// + // 1. Describe A and B tensors + /////////////////////////////////////////// + using ElementA = // TBD + constexpr int AlignA = // TBD + using GmemLayoutA = // TBD + using ElementB = // TBD + constexpr int AlignB = // TBD + using GmemLayoutB = // TBD + + // Mma's accumulator type + using ElementAccumulator = float; // Always float for block scaled tcgen05.mma instructions + + ////////////////////////////////////////// + // 2. 
Choose Performance Parameters
+  //////////////////////////////////////////
+
+  // Tile and cluster shapes
+  // Collective MMA takes tile shape of the MMA operation as input
+  using KernelMainloopPolicy = // TBD
+  using MmaTileShape_MNK = // TBD
+  using ClusterShape_MNK = // TBD
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,  // Arch and Tensorop spec
+      ElementA, GmemLayoutA, AlignA,                                    // A tensor elem type, layout and alignment requirement
+      ElementB, GmemLayoutB, AlignB,                                    // B tensor elem type, layout and alignment requirement
+      ElementAccumulator,                                               // Mma instruction accumulator type
+      MmaTileShape_MNK, ClusterShape_MNK,                               // Mma instruction tile shape, cluster shape
+      // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelMainloopPolicy                                              // Kernel schedule policy:
+                                                                        // Auto or a targeted scheduling policy
+    >::CollectiveOp;
+```
+
+From the valid type and layout combinations in [Table 3](#bs_gemm_table), we see that only **row 3** supports the `mx_float4_t` x `mx_float4_t`
+combination with the NT layout. As a result, we need to use the `tcgen05.mma.kind:mxf8f6f4` instruction. Additionally, in order
+to use `tcgen05.mma.kind:mxf8f6f4`, both the A and B tensors must be 128-element aligned.
+Thus, we can describe the A and B tensors as follows:
+
+```cpp
+  ///////////////////////////////////////////////////////////
+  // Mainloop Builder Setup
+  ///////////////////////////////////////////////////////////
+
+  ///////////////////////////////////////////
+  // 1. Describe A and B tensors
+  ///////////////////////////////////////////
+  using ElementA = mx_float4_t;
+  constexpr int AlignA = 128;
+  using GmemLayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = mx_float4_t;
+  constexpr int AlignB = 128;
+  using GmemLayoutB = cutlass::layout::RowMajor;
+```
+Next, we need to choose the performance parameters, namely `MmaTileShape_MNK`, `KernelMainloopPolicy`,
+and `ClusterShape_MNK`.
+
+The `MmaTileShape_MNK` values supported for `mx_float4_t` x `mx_float4_t` with `mxf8f6f4` are listed in [Table 11](#bs_rows_3).
+For the NT layout, we see that three `MmaTileShape_MNK` values are supported: `128x128x128` and `128x256x128` with the 1SM instruction,
+and `256x256x128` with the 2SM instruction. Say we expect to get the best performance with the `256x256x128` MMA tile shape
+for our GEMM problem. Then, we need to set the `KernelMainloopPolicy` to `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100`.
+Now we need to choose the `ClusterShape_MNK`. Since we have selected a 2SM MMA instruction, `ClusterShape_MNK` must be
+compatible with it: its first mode should be a multiple of 2. `ClusterShape_MNK = cute::Shape<_2, [_1|_2|_4], _1>` or
+`ClusterShape_MNK = cute::Shape<_4, [_1|_2|_4], _1>` would be valid options. Let's choose `cute::Shape<_4,_4,_1>`.
+Our performance parameters look like this:
+
+```cpp
+  //////////////////////////////////////////
+  // 2. Choose Performance Parameters
+  //////////////////////////////////////////
+
+  // Tile and cluster shapes
+  // Collective MMA takes tile shape of the MMA operation as input
+  using KernelMainloopPolicy = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;
+  using MmaTileShape_MNK = cute::Shape<_256,_256,_128>;
+  using ClusterShape_MNK = cute::Shape<_4,_4,_1>;
+```
+
+After configuring the mainloop, let's set up the epilogue.
+
+A typical epilogue setup looks like the code below: we need to specify the output layout, data type, alignment, and PerSmTileShape_MNK, and we can leave the rest as default/auto.
+
+PerSmTileShape_MNK should be deduced from the mainloop setup. For example, in the mainloop setup above, the MmaTileShape_MNK is
+256x256x128 and the KernelMainloopPolicy is a 2SM policy.
+This means each of the two participating CTAs (one per SM) produces a 128x256x128 portion of the output tile, so the PerSmTileShape_MNK is 128x256x128. The possible PerSmTileShape_MNK values
+are listed in [Table 15](#epi_persmtileshape).
+
+The epilogue scheduling policy is configurable: it is common to use `cutlass::epilogue::collective::EpilogueScheduleAuto`
+and let the epilogue builder select an appropriate policy automatically, but it can also be set explicitly to
+a policy matching the 1SM or 2SM MMA instruction. The available policies are listed in [Table 14](#epi_dispatch).
+
+```cpp
+  // Describe C and D tensors
+  using ElementC = cutlass::half_t;
+  constexpr int AlignC = 8;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::float_e2m1_t;
+  constexpr int AlignD = 32;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_4,_4,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_256,_128>;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,  // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK,                             // Epilogue tile shape, and cluster shape
+      cutlass::epilogue::collective::EpilogueTileAuto,                  // Epilogue subtile shape. Auto will find a suitable tile shape
+      ElementAccumulator, ElementCompute,                               // Mma instr's accumulator type and compute precision for epilogue
+      ElementC, GmemLayoutC, AlignC,                                    // C tensor description
+      ElementD, GmemLayoutD, AlignD,                                    // D tensor description
+      cutlass::epilogue::TmaWarpSpecialized2Sm                          // Epilogue schedule policy
+    >::CollectiveOp;
+```
+
+If we want the epilogue to also generate mxf4/nvf4/mxf6/mxf8 output (i.e., elements plus block scale factors), we need to set up the epilogue fusion in the builder.
+First, we need to choose an SFDVectorSize, which indicates how many elements share the same block scale factor.
+Then, we need to choose ElementSFD and GmemLayoutSFD, which indicate the scale factor data type and along which output dimension the block scale factors are generated.
+Typically, GmemLayoutSFD is the same as GmemLayoutD.
+
+```cpp
+  //
+  // Construct FusionOperation
+  //
+  constexpr int SFDVectorSize = 16;
+  // Define the fusion operation applied during epilogue
+  using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
+      SFDVectorSize,
+      ElementD, ElementCompute,
+      ElementSFD, GmemLayoutSFD,
+      ElementC
+    >;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,  // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK,                             // Epilogue tile shape, and cluster shape
+      cutlass::epilogue::collective::EpilogueTileAuto,                  // Epilogue subtile shape. Auto will find a suitable tile shape
+      ElementAccumulator, ElementCompute,                               // Mma instr's accumulator type and compute precision for epilogue
+      ElementC, GmemLayoutC, AlignC,                                    // C tensor description
+      ElementD, GmemLayoutD, AlignD,                                    // D tensor description
+      cutlass::epilogue::collective::EpilogueScheduleAuto,              // Epilogue schedule policy
+      FusionOperation                                                   // <================================== Pass the fusion config into the epilogue builder.
+    >::CollectiveOp;
+```
+
+The example above is a gentle introduction to using fusion operations in the epilogue. For a more detailed example, see
+[Blackwell GEMM with collective builder](../../examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu).
+
+Note that, for clarity, we discussed the CollectiveMainloop first and the CollectiveEpilogue second.
+However, the CollectiveMainloop needs to know the SMEM utilization of the epilogue, so the CollectiveEpilogue must be set up before the CollectiveMainloop. See the [examples/72_blackwell_narrow_precision_gemm](../../examples/72_blackwell_narrow_precision_gemm/) directory for the full kernel and run setup.
+
+### Scale Factor Layouts
+
+The scale factor layout consists of a 512B basic-block structure, as illustrated in the diagram below. Each basic block covers 128 rows along the M/N dimension and 4 scale factors (SF) along the K dimension.
+The byte order of the basic storage chunk is row-major, meaning that M0SF0 to M0SF3, M32SF0 to M32SF3, M64SF0 to M64SF3, and M96SF0 to M96SF3 are stored consecutively in GMEM.
+
+![Scale factor layout of one M128 x SF4 basic block in GMEM](../images/M128xK4_scalefactor_gmem.png)
+
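+To make the ordering concrete, here is a small, self-contained sketch (my own illustration, not part of the CUTLASS API) that maps a row index `m` (0-127) and a scale factor index `k` (0-3) to its byte offset inside one 512B basic block, assuming the 16-byte pattern shown for rows 0/32/64/96 repeats for rows 1/33/65/97 and so on:
+
+```cpp
+#include <cassert>
+#include <cstdio>
+
+// Byte offset of the scale factor for row m (0..127) and SF column k (0..3) inside one
+// 512B basic block: the first 16 bytes hold M0SF0..M0SF3, M32SF0..M32SF3, M64SF0..M64SF3,
+// M96SF0..M96SF3; the next 16 bytes repeat that pattern for row 1, and so on.
+int sf_offset_in_block(int m, int k) {
+  assert(0 <= m && m < 128 && 0 <= k && k < 4);
+  return (m % 32) * 16 + (m / 32) * 4 + k;
+}
+
+int main() {
+  std::printf("M0SF0   -> byte %d\n", sf_offset_in_block(0, 0));     // 0
+  std::printf("M96SF3  -> byte %d\n", sf_offset_in_block(96, 3));    // 15
+  std::printf("M1SF0   -> byte %d\n", sf_offset_in_block(1, 0));     // 16
+  std::printf("M127SF3 -> byte %d\n", sf_offset_in_block(127, 3));   // 511
+  return 0;
+}
+```
+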
+
+If the scale factor tensor exceeds M128xSF4, there are multiple basic blocks along both the M and SFK dimensions. The arrangement of these basic blocks follows a K-major order. The diagram below illustrates the scenario where M equals 512 and SFK is 16.
+
+![Basic-block arrangement of the scale factor tensor for M = 512 and SFK = 16](../images/narrow_precison_multiple_block_sf_layout.png)
+
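+As a rough size check (my own arithmetic, not taken from the CUTLASS sources): in the M = 512, SFK = 16 scenario above there are 512/128 = 4 basic blocks along M and 16/4 = 4 along SFK, i.e. 16 basic blocks of 512B, or 8KB of scale factors per batch, arranged K-major. The sketch below generalizes this, padding M up to a multiple of 128 and the per-K scale factor count up to a multiple of 4, and assuming one byte per scale factor (both `float_ue8m0_t` and `float_ue4m3_t` are 8-bit). Use the `Sm100BlockScaledConfig` helper described next to build the actual layouts.
+
+```cpp
+#include <cstdio>
+
+// Illustrative estimate of the SFA tensor footprint (in bytes) for an M x K x L problem.
+long long sfa_bytes(long long M, long long K, int sf_vec_size, long long L) {
+  auto ceil_div = [](long long a, long long b) { return (a + b - 1) / b; };
+  long long rows = ceil_div(M, 128) * 128;                        // M padded to full basic blocks
+  long long sf_k = ceil_div(ceil_div(K, sf_vec_size), 4) * 4;     // SF count along K, padded to 4
+  return rows * sf_k * L;                                         // one byte per scale factor
+}
+
+int main() {
+  // Hypothetical M = 512, K = 256, SFVecSize = 16  ->  SFK = 256/16 = 16,
+  // i.e. 4 x 4 basic blocks = 8192 bytes.
+  std::printf("SFA bytes: %lld\n", sfa_bytes(512, 256, 16, 1));
+  return 0;
+}
+```
+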
+ +The creation of scale factor tensors' layouts are tedious. CUTLASS provides `Sm100BlockScaledConfig` to create these layouts easily +(See [sm100_blockscaled_layout.hpp](cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp)). +The interface to create SFA and SFB tensor layouts is as follows: + +```cpp +auto problem_shape = make_shape(M, N, K, L); +using SfConfig = Sm100BlockScaledConfig; + +// SFA shape: ((32,4), ceil(M/128)), ((SFVecSize,4), ceil(K/4), L) +auto layout_sfa = SfConfig::tile_atom_to_shape_SFA(problem_shape); +// SFB shape: ((32,4), ceil(N/128)), ((SFVecSize,4), ceil(K/4), L) +auto layout_sfb = SfConfig::tile_atom_to_shape_SFB(problem_shape); + +auto tensor_sfa = make_tensor(aptr, layout_sfa); +auto tensor_sfb = make_tensor(bptr, layout_sfb); +// Access SF for for element m,k of A tensor +auto val_a_mk = tensor_sfa(make_coord(m,k,0)); +``` + +# Copyright + +Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/media/docs/dependent_kernel_launch.md b/media/docs/dependent_kernel_launch.md index 76eadd20bc..3fcbfeb2b2 100644 --- a/media/docs/dependent_kernel_launch.md +++ b/media/docs/dependent_kernel_launch.md @@ -2,19 +2,24 @@ # Dependent kernel launches -The Hopper architecture supports a new feature through which two kernels in the same stream can +The Hopper and Blackwell architectures supports a new feature through which two kernels in the same stream can overlap their execution, named [Programmatic Dependent Launch (PDL)](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization). This allows kernels with conflict in global memory to programmatically and safely overlap portions -of their execution. Primary kernel can signal it is about to finish execution, and the next kernel can -optionally wait on the previous kernel to finish flushing its memory. +of their execution. 
Primary kernel can signal it is about to finish execution, and the next kernel is expected to +programatically wait on the previous kernel to finish flushing its memory. + +We enable PDL by setting a flag through the extended CUDA launch APIs. All CUTLASS kernels with PDL support +will wait on the prior kernel to flush its output to memory and signal the next kernel to start. This means +they can safely be dropped in with any other set of kernels using PDL as long as they also adhear to waiting on +the prior to flush its memory as well. For more information, we refer you to the [PDL section in the CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization). ## Using dependent launch in CUTLASS -When building CUTLASS, you can use the `CUTLASS_ENABLE_GDC_FOR_SM90` macro to -enable PDL-related instructions in Hopper kernels: +When building CUTLASS, you can use the `CUTLASS_ENABLE_GDC_FOR_SM90` and `CUTLASS_ENABLE_GDC_FOR_SM100` macro +respectively to enable PDL-related instructions: ``` cmake . -DCUTLASS_ENABLE_GDC_FOR_SM90=1 @@ -30,3 +35,10 @@ gemm.run( /* launch_with_pdl = */ true );_ ``` +## Model-Aware Optimizations with PDL + +In [example 63](../../examples/63_hopper_gemm_with_weight_prefetch/README.md), we use PDL to explicitly optimize for +performance of kernels where we know that one of the input matricies (our weights) will not be produced by a prior +kernel. In that case, we only need to wait on the prior kernels memory flush in order to load the other input matrix +(our activations). During our prologue, we can prefetch our weights to improve performance for memory bandwidth-bound +problem sizes. For more informations we refer the reader to [the example](../../examples/63_hopper_gemm_with_weight_prefetch/README.md). diff --git a/media/docs/efficient_gemm.md b/media/docs/efficient_gemm.md index 4defa6d857..470c4eee79 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/efficient_gemm.md @@ -219,7 +219,11 @@ which has to happen at the end among the participating warps. This is because each warp computes using only a "slice" of CtaTileK, so each warp only has a partial sum before the reduction. -### Warp Specialization +### Hopper Warp Specialization + +Note: the following section on warp-specialization contains details that are specific +to the Hopper kernel design. Blackwell SM100 kernels have a substantially different warp-specialization structure, +however, the concept of separating out producer and consumer agents still applies. Starting with Hopper, CUTLASS 3.0 incorporates the concept of [Warp Specialization](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#spatial-partitioning-also-known-as-warp-specialization) as part of the kernel design. A thread block is partitioned into two sets of warps, [*producer* warp group](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [*consumer* warp group](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp). The *producer* warp group loads data from global memory into shared memory buffers using the new [Tensor Memory Accelerator (TMA)](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). diff --git a/media/docs/fundamental_types.md b/media/docs/fundamental_types.md index 311a80e659..e50cd38406 100644 --- a/media/docs/fundamental_types.md +++ b/media/docs/fundamental_types.md @@ -20,6 +20,20 @@ CUTLASS defines classes for the following numeric data types. 
* `tfloat32_t`: Tensor Float 32 data type (exponent: 8b, mantissa: 10b; literal suffix `_tf32`) * `int4_t`, `uint4_t`: 4b signed and unsigned integer (literal suffx `_s4`, `_u4`) * `bin1_t`: 1b binary numeric type (literal suffix `_b1`) +* `float_e5m2_t`: 8bits signed float (exponent: 5 bits, mantissa: 2 bits) +* `float_e4m3_t`: 8bits signed float (exponent: 4 bits, mantissa: 3 bits) +* `float_ue4m3_t`: 8bits unsigned float (exponent: 4 bits, mantissa: 3 bits) +* `float_ue8m0_t`: 8bits unsigned float (exponent: 8 bits, mantissa: 0 bits) +* `float_e3m2_t`: 6bits signed float (exponent: 3 bits, mantissa: 2 bits) +* `float_e2m3_t`: 6bits signed float (exponent: 2 bits, mantissa: 3 bits) +* `float_e2m1_t`: 4bits signed float (exponent: 2 bits, mantissa: 1 bits) +* `type_erased_dynamic_float8_t`: Type agnostic 8 bits signed float allowing the user to provide a specific datatype as runtime argument. +* `type_erased_dynamic_float6_t`: Type agnostic 6 bits signed float allowing the user to provide a specific datatype as runtime argument. +* `type_erased_dynamic_float4_t`: Type agnostic 4 bits signed float allowing the user to provide a specific datatype as runtime argument. +* `mx_float8_t` or `mx_float8_t` : Block scaled data type with fp8 element type and float_ue8m0_t scale factor and vector size of 32. +* `mx_float6_t` or `mx_float6_t` : Block scaled data type with fp6 element type and float_ue8m0_t scale factor and vector size of 32. +* `mx_float6_t` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 32. +* `nv_float4_t` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 16. * `complex`: defines complex-valued data type based on the supplied real-valued numeric type Numeric types in CUTLASS may be used in both host and device code and are intended to function diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 846cfb5422..6383f97941 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -115,6 +115,10 @@ usage: ("s1688" and "nt") or ("s844" and "tn" and "align8") in their operation name using --kernels="s1688*nt, s884*tn*align8" + --kernels-file= Same behavior as `kernels`, but kernel names are specified in a file with + one kernel name on each line. Set of profiled kernels is the union of kernels + specified here and those specified in `kernels`. + --ignore-kernels= Excludes kernels whose names match anything in this list. Device: @@ -284,6 +288,8 @@ GEMM [int] --max_cc,--maximum-compute-capability Maximum device compute capability [enum] --raster_order={heuristic|H|along_m|M|along_n|N} If supported by kernel, sets the tile raster direction [int] --swizzle_size={1,2,4,8} If supported by kernel, sets the 2D tile swizzle extent (In Hopper, other values will be rounded down to the nearest supported value) + [int] --use_pdl,--use-pdl Use PDL (true, false) + Examples: Profile a particular problem size: @@ -323,6 +329,8 @@ Profile when execution is performed on device 0 and the C tensor is located on a The format of tensor argument is followed by `:`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](tools/library/src/util.cu). The layout could be `row` or `column`. +CUTLASS 3.x kernels for Hopper and Blackwell also support a new feature called programatic dependent launch (PDL). 
This can be enabled with `--use-pdl`, and can overlap the epilogue of the prior kernel with the prologue of the next kernel. This can effectively hide kernel prologues. Using PDL can improve performance for back to back GEMMs. See [dependent kernel launch](dependent_kernel_launch.md) for more information. + ## Example CUDA Core GEMM Operation Example command line for profiling SGEMM kernels is as follows: diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index a217e0e78a..dd1b0c6fc2 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -24,7 +24,8 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc $ mkdir build && cd build -$ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture +$ cmake .. -DCUTLASS_NVCC_ARCHS=100a # compiles for NVIDIA Blackwell SM100 GPU architecture ``` If your goal is strictly to build only the CUTLASS Profiler and to minimize compilation time, we suggest @@ -653,6 +654,105 @@ targeting NVIDIA Ampere, Turing, and Volta Tensor Core operations $ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=tensorop*s*wgrad_optimized_f16 ``` +## Instantiating a Blackwell SM100 GEMM kernel + +Blackwell SM100 kernels are instantiated very similarly to Hopper kernels. Let us start with an +[FP8 GEMM without blockscaling](../../test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu) +as an example. + +The kernel starts with setting up datatypes and cluster shapes. +```c++ + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); +``` + +The epilogue needs to be instantiated first as the mainloop collective builder takes the shared memory budget of epilogue in the template parameter list. The 3.x epilogue collective builder API has not changed +for Blackwell, so the epilogue fusion is built in a same way as an SM90 epilogue. + +```c++ + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + + using FusionOperation = cutlass::epilogue::fusion::LinearCombination< + ElementD, + ElementCompute, + ElementC, + ElementBias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; +``` + +One can refer to our Sm100 unit tests as examples of how to correctly +choose mainloop schedules. 
All of our dispatch policies can be found in [dispatch_policy.hpp](../../include/cutlass/gemm/dispatch_policy.hpp) +and more comprehensive Blackwell specific documentation for valid +dispatch policies can be in [blackwell_functionality.md](./blackwell_functionality.md). + +```c++ + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; +``` + +It is worth noting that the mainloop builder takes `MmaTileShape` while the epilogue builder takes `OutputCtaShape`. + +Instantiating a blockscaled GEMM kernel is slightly different. Referring to an [MXFP8 GEMM](./../../test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu) sample unit test, it takes a different tensor operation class: + +```c++ + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; +``` + +are needed in the mainloop builder: + +```c++ + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, GmemLayoutA, 16, + ElementB, GmemLayoutB, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelScheduleAuto + >::CollectiveOp; +``` + +We encourage a user to refer to Sm100 unit tests and the generated profiler-based kernels as more comprehensive samples. + # Copyright Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/media/images/M128xK4_scalefactor_gmem.png b/media/images/M128xK4_scalefactor_gmem.png new file mode 100644 index 0000000000000000000000000000000000000000..f1d21dd215df369520f3888012a165c0db217795 GIT binary patch literal 224698 zcmd?Q1zQ|l@;8hGcL?qfLV~-y1PdPA0|a+>x8NZ_a1BmycXxLk+}#HT=AG7GK)f`URW zw-6Usk`WiDP;#{UVqt9x1tlGkn2hjA^$^=1>}AJ?#7OCk*oHKXL_u={C0oHIhE4IE z09I7gdT>btgEUqS|vOS}J8LqN*C|@Vl@1nZH{s~wG=QPDt~&g zEpFC~D&C`~nO=TXvM7I5MQfGByVcbE?K@(zaKuE?-h={jZ0hbu(Vbi^1w|y3ACZ^e zAN#&3S%AHBDh71*$Pd?Nroq()&^par)bCdno888B!QPZsOlF?7wcKQY5ih230}jHVsr)%B;nX{!v)8^{ zXeeeVZ#0R0r@?#k)|5r_fN`6BQS4WrKxp01Iz}i}H`hJs*2A>+NX!hDeViO1=@fhH z>^N|14$547BA(7+yb$28e{T+@&23M!HA)DWbrJJJkrqLp^rxGI@$4k{k@vg=Yqk;N zdV&!s!V?ZX0)_u3fP(_QR0Mqy@k*3(5#BXWwi16f@OcB12WGib#vWbwl|~nuJ%Yfi zXG6qRsO=5XAHfue7|IlM`N$UH@NtADXsnbPgT%wR z4I;|P%E6iBJz$$fSqp>DRzBdoLHc0G_A!XlNZeg$G+4_XiUgfG1ZOkRo+}-}E@Whr z=z(GtIV-55+ik=50Zt$QN&-HInlGHUTS6e7;tigRNbewDKGb`0%eXK3u!kRPDW3-k zuCRi7lvq)tc>4*r!y-EES?O7iStNhRj{1J2FFbipu*K9wqW6=ws8U zQHmTJEF1aIR}Gv2|HniYyucZAXsMyU^EevJ5mq)rr&6yow%TKnl6Ux;WF3 zE0cv0^`ZLV_3u9KrM|Pr84ocgj3m(~@YA{J*?!D?XZP;>{Y1P0tpSan*h?lT7)Gs2 zmQtEhK1X`fzuPnfOR_aLC%1u0kyWuu{jzn70ypPKS+H35=&dE`EsRTeXddCzVy%fb z>mBbJ9vza-TF~^9N?-Ch6ynXNeaU<_(In@&NhZ zPJi=G4r#VU(@C3)^~ph&Vfz}Rw9WANuUISZR*HrW;#vC^$J~E9L|@wAq$AC|=`(2< zLQ`Fw9p@OT9lmKMwEklEmSx3eo*#pM!XeXMf2`GO&%`Ooqrf8p=nuC>HA%IQJGQm! 
z{S2(SZkRaF1kX*6v&~hHU*#m`CU&U(QVW~2vQoAB<&t=7L;N*5cFZ%)HBF=bOTG7s zV7-s)rB!OdK}ueWd>pu?i2_W zq|#rUx!e>Sgf;}KJI$f}MArgE@%8XyMud`mHA1b@k7Iv5+w$3xG!Qiuo_U@L+| zB_YwLxWwH~_e_hQIIM2B6(;O!Pt@p#Cs$|$Xv8+;x&^!ayvHHe%aX`?o8@4@58!M! zd#?0__FZ{CtKwYJb*1)N1$6n9`)NS&QZQt*ZCGSmWb;%SQ?^d}ZN`Me;t?`nGg+d* zh@NDVipB;-2Wr5Bu>(<%(2J2h@GNmtah6Fxpy`mgQk>{DO1V2c^dN2{%1D^y>+XCT z>>s>W)>o!cexAC{TlioQtsIHfTRX62Ol^E@OpE8myx!z)9k7pOE!wo@mbHeY3E?Psd#`a?;h&M91MJyNVUP`-tUQKB-;h zDOoZ2@jLU8!C|fA{;-3Cf1?notU>B}rl@#Wo)xz7klVL_j|-zEqZp-B8l-yVmD-Ju z+DR7kNoAYXStnWZhV%9d&U*(t)}w@)Bib!?q_s@Gdi$Pjw;sn~To1O39r3lSrrQ-w zNli3OcHELSM}sq?D{gx^3erlaWh@uHKt*16TKbEdv8R zNVk`<`6_4*V^`RFCAG4yaIxKef@_s)v%Y0zt^S-@`P~o`;cAt=l-2vZ+U%Rt@m+rA zN$u(NYZXLDAF>Wk?9X^!BKw%pXk8T4GgSE0yi=X|&Tksnm(1Rde;tp>kQ3luY-+vU z1Rh*Bw1_o(GQHAmbe`JZ0@|wIr^qoSE9kp6dv76M%dV(Bm6k7@{2~R-u-oxHde*>TBmXG0 zdV%;N1KT%(j>4YFQ^;F{C4Go8b#Dp5O`yo3Z{_9PC)#ErHTg$Zjt#javSI$@!oFm~{&|M6`(02} zRa`~}a#uBWG&Qw#GPiS1Wa+VlkXo?#r0J|F|B=tw&W6dz#O|{xle>-m?PG+W}_(u_ED?utvc_j*QJ4aIrE+!Tx7AhfR3JMAV zN0TpnDiR<5RUPssNM-KqY|qEc?B?div%FD~k%)-XZ#>NOK!RX{+>uluC zXzN7%Pa=QlNSHbqJ6hN~TiDrB{HAO4+0MmTkc#SeLx27L$)~Bi#lKs!b^4cB5CNHg z=PmV|{3L_Lr;gt*w)1^!X~zjFTF;(yfC{C7=O7B;s3 ztok2W|F^1|lc}S)oeiW-XQ6+0&A%%DXXd{u3NZf`{XbOkPd)!56{2V%WC7;CI!y>! zo=+MM3Q7b@Mnd$HJM@Vjf;+C{f^Wj=4cRvlF%+8M^)y5(72MYnUty@|vhgotAX>#t-4wd~vj0JS z7xiP7kRIKyvSBY9vB{2_Fpoe{R{sgqU3&!|#(&{ze7A zX>tzuBffABwCQU7U3C$HB1ecCDjdJ=!u*#8|G|>T7%N4f-uCG>C>VH@dG(773uauOOhzx$Wt(~C+dpEZMH11V@cr>CSDy#W2 zQD0x7$GcM$e*PA*Icq>WDJv^$<7xehp`Bf9BE1HcprGD}YKO0n^U*BQZm~GMZa-qf zQ;Jz;SAl%mFFGT}(~?m~vA=Ko0iRtdcf)UI5KG{8R&{WASnTQU6pNfZ)?%h$GmP9< zB$-L?;$b_s$DHHy-N};d1Gkg?3&`i)8XCGJlgnngb7Dd<<9xNL1{*Y@ z8rJkuedgro7?+f!ZtZGi(T5k8@b~KN?D!t=84i4MV*Qg50|liA(jub%QEX2NyICy{ zd01%Z=NAw#W<@mF zN=_IAp}pg3!-*j5q?~+?akrZTWe)2_Os~5WUa(h| z4O03EI%+K>=wr~1?D6=sSimkll;zY!)Udn`sDA4m&XZ{G^H3#^YsWePAel?tf^+zx zx7N{xT~1;9G}_^idsqjntbQ_M5+RXCk`>t(bc~^HZ?W5AO1}K&Xqyj2g6}?$Xjc(f;X1R z+0z|@!fG-7hC!$PBM|0bt~6P;0d#*eqi7X`Xy6N^N#(Zt^=(8I0srFuejv;51?di_ zGu7vgztU=f$=$=lXsy{rv+8rm^txgSn`B+s$VK|gw`raXaq!FY<*L&}3i?jRpFYaw zhW|FvRe9$@A&kNMEF4;;EX*VE0-HaOqAi2|Pm z0p||Cp@)eYz5i_x)UdIlw*Au&@%+ni|L38fgI1GrHcJiTV!mr<1)d4x-2uZ+gKjzR zz0egnlqS*d_PMA!&@G01h~CRL>v3ppah||7s$_m!{)T5$m%g_9T-i~H#?simc*E5> z2AG9^oUoZ$Hib)2_owD?-)ua=qM*g)Z3Dd7e6}rlM~u!v042rU_mj6nPvI>|3>e9n zIQ5tGsT&X5Lb?PHnSmfW8iC4eBQXFki2c$>c0^;IuNDRy>(5G65&;>aeq5(!PM+wH z-@$TSDK-dv-vKt7%oa`Em^h@Cis@Z2@QbF+bQjY?j3N_!q%-qNKM%pz7lADOJSD-d zxJ_3TWiIxv+}zwW2Y%OM+{5QkDBu&GxN>xNk#6fm7L!3t`3qrRkb8v1p6jYxmhULx zn9aAgKl(TP(z$tqytw$t%p$jD!0uCb`;5Uo zHDF0@#*=szHlEW#m?T-C=xT*FJ#>crp5pk-^gAtWZ+$T{s|4qr+^-CZ{X#+b|`!TpcA~Lt^iiw=$A9 zYPJ2`(%oI8bM;p%->I5lpHX%C-G=>^;-YCGRl}G&L5srl5u2;8Zr^FHYjGQEJ~N5= zWbezMH_gMff4H7DYE(2r7o~TTx`L!TCI$0=^Y}*I&%{#n^6wIVb|?WyM%jzIAnjW7 zoY4a<)p4??B@*yi%Xw&I(T3+h40@^b1fjsHZFI`=lKzam=Vx zY^ewMPA1--=fPN1UCPY_aEmf#dM+l$Xt8b)NAZ+tN|zx&Uvs9X=(O7MdZ+M5PlT|c zl$iTn!H;rWo}JZ4n(9m*SmOb3k#Q;bYnMq1fa&Z9 z->8*6PFxDq{?gc{0R2ZBbbg`v>gqW$GA-uCtThpJ82s6EpZQB-otWvXAJphr^qW_b za&^2Udgemq4CKSwVd~OUe@$VWtV(4YfGT6Do0BD+lek#cXyt}LxZV}?d>3F zSa^Rpzx9O@wNi6Dmr-rt_wTRkE@h#Uav+vTIKO^<1V#NJN@?#|mo`kjT)UX-%1}~8 z1^dfb>d*1yqyp1#f6dt0X7D5*gZ5l`Ei~$;nI8T5%&fUZQ!a>g*79ObvAgS9KTLb> zbCUVDVoq;3W%Hf`eT@nl6Yx_`XZ&GOE_Fh^8$?{pDJYj-W)npN{pXXketk{;$*6&w zi{pVd7(jNK@sbKv2pD{m#Q-IX(q}~XDYO99Ug`SD<547G#a!Q4uCfElHO8E9Q6`r( zgXJVsUNgC*7#+A?KoiAyU1i*vE}B(3CT&xXYFge?9o741#~5&An625B#Vlalk9u=B zBZUNRO$T%az|ML5&eu8_%0aF3RC)-9u^RUmwE;}Ixy)_D>NuRpoOH$(;{O`D*|hLiKH_Rl zCyFmW_p+^Q@_=WR+-S)DqswY6=NSUE)_c-I9(5gbJGlBtQgRdA9+eFpmAI}$W<#}L 
zPVDfODU{V4--Z!+ZqqR?wC++^E1N16CT8n*)dOWJGbj(9F(dK-WNn=?iLZSUz-w6T zVik0R_PjPy$0t8TYX6_sID z{o2KZ`A?Gv5zxlh{6>pBxOY`jn>W{tJ7wn|b`hBYE$!1L#(;(ILeZ_43B`|@OC^WH zOPQ=SCIma_G5aSgePn?~J;$})faj2_<%AkRbPv&t_;b&`+}Hx7ofd@RYjgx(ikXLZ zY|&sD($dz}H~L66L$ovA=zWP!i8&{!ds&f{*60Ak4qDgG?%1R=i z{~Ct3EK0fkfU7I(cDwgCPN`Z4-23#Wi|G@N2HGjPhkMm!8QhBRSqM;5ccBYwT%b{& zNxmjAzb`izQqz1v`m=xb{Oq5UGD3FQqU|{mleEHq@)GrGCSj*7YxLfN$<&S$p$eL*y9lh~KYq*G#9Jd}8r$qg9(%aOu4n8k$v^)p$I-P^3MyPWNb z>a^Jb=uT~|`-hk&2j+_Uprv?9J9Lgt<#!l~b=n9GGJ$_qHYHfp#rP!wo%+q$;=o;V z%WP6`s@x0cdg6zG$5lr`+qBX@4;=wkaRJt;`-cQ}=)f_`HhR zeh5Ub_uqShmd$LMySj~-nUSPrWEa(Le_ZiTSR%&}`ppMDncGvqDj2vywH{F(Z&IhD zP@HyX->=6rzyO!um|NxK*K5)z%lsI;hhpXorx!x2edG}<#W&|+*#12O#i5%MN=%Bt zok)kfVZ4*xb*;7wqMyOXMYkG&qe1yxUzC^474cX1aTyyv-W>kmbD=vpr4{}6!Q({@ zl_!2kz)lhtWnKR^GLw;Er^eq3VkWi?^j|=G-!!J9qv5;l56l|lxV_PV=%L-I?%H@t zC`n?%Sq#OD>FO%H_~iZ#-vUir7=N>+jiAGG;V(9&q}#%H{aU`^Bl|$)-uu%YLiX+4-}|{ytF`vBED70y-6dJU{mP7BsDPy8lr4zbA4- zRd~(41&$Z*$&y5iCE_&wU80iF)*HZ30&NB5Cm2bBMlM1Yodbp?Y!@4t*PpsOIeRNz;zh5%ve?_O8RCN~Vd{5BzggjAXJfgvkReaPuTK13OOF5U9nYbJHlC#iL{m8?pocH}=Yz$qwiYykl$KQv#h zav{3j-3^Wme*hK84@L^s-nnl9-E9_#S!wi&}^7m%V~h3>bsOziZe?%(5&>v?rys!X~XdoxsXl}Y$&Ba`qu z2dR$SmFVt0?hYy{&K%iWPsI##O8UCuRoC13Uy8jU=~?cK7+$10z@pzCmXrUmcSK{~ zUMx%kg0Y6i_3X7gB`iiY*ZW*Itw9&W>DVVaXF!Z=0OT8r zuWD?(1kM(5FP5w6*9_TD&|RX9Ox##!96&uyiw@MpRS6$*R~gPj*m)nb(6toiTJ_ zTj^kx=}V=`-kTXa*q&l}lJDl&p6?y$TcQ1ukcibxp#E(Cl0`j8UDm}N(pBMMPlJi~ zT3)Lz$e`ll1h8B3l3IU@>1d3-!MXEc&LDcoLRzZ^rT>RSql*pF!X4bX(l6Gh6PkP* zK9TcC)W%zLw1xQsyPVZUiTtuvQNZw_{rR#ieR8Weu@AkOWwe9M%Ik6Ya7+X^I6 z`HUMl5Ea^1N{_DKO~5Lb+%s<2nR1xbxjuxnPRdU~&j;TEk0FVE8vSM`3wWS9OXxsL zvbt7vTyHe^^{d$4l9M#Q1NOM6=n)^y+)vJIWGc>pL^uO7OkSQ#tSzA3?2!|@?$KO1 zjB`k78B*C@416=dZ90+;MTXy%%5l&7Va=I3{dksXm4}#o4pm+1h@9F;@tI+Fl2XF^ z0%>h!-M)U)ZZe9H{k5twnYV&5gR8Zl$*)J8n6Z`4$?PguA?_oRE!x%4GMAAJ*L|Ly z_zN;OfZ?L@lVc32$kECDhPSZ+&%**d24ZgvsXr*4VdrEmX4XzPMD>2%9+4u4;*Ra} z09rP5@W|kVwA2wkd8Z5Frn8gAe)JT+Os8<<>czHw5h-dIynNin$t#mzTBro!D9te$ zc{u6Ic&Nyh(F|r1cA6gcZ=9Zy-v~&#YcPhQToan4CD*N2UbQhC629NJX{BpLd^j9lM_pFLB#(8=X& zPseIA+7*pqAGtH$ZL66TECO_J3LG!c(Uxr09HY?#MgX=RI%r3_%bv(y!n!e@KH8mv zMOT3dXO(<>N&t`8OdW*YGvUZJJ8;P3vrboez1t;sX>iR<0kiUZ3?ab}BM!~rH*c8N z5jC(YGl=oM3r({kKlE-4No&=k`5Gh8LB9pn4-E?WWVX9GI(RHM*uOREXK1BiUJD$o zQdMj7NLd=>^Tf++aL%>caSz6DCu*8Pdb49jxfj%QLFj>UvZm*dI=-(PC5etkU_vU4 zGLQ|$!->@2gai)D@8BvvKRdYJ>l|N#YET%f>BJ zG{MD|v0k}6#Xfq^FU*f{xPn2wmk@Gsum&CDLI-}z=P$eQMIr~0_}~-8d2Ddx0@CE} z>v2e|o)V9zp7*w!=!7Fdxt9&uwx`Q=i>)5)R8&+-Ev`1s5dR^=iSYF(yS=@A<)9YN z5#I7^&X`{=Tg!?IFe$l zSCVgoIb#|4Wj#OCrWr(pXIAc+efhqqK?ki&e?*KIY%FZ^Aqd*0)@i_ZHCVg+nDT7W zV-BE+(gECsJ9vD|F!%7gRQKdBc#a7TTBiCuq+gziavWQH73Ak5d#f>T%kRl!b*NXc z8nBTGmn4MGMbfiNx=GMxTr%y8;EdX}emFAum?Ko_RsIJ7)YWCb#{+Jm&~v;zk`HS=oAWoh?7mtsc9Vx~;UL zS8R2MZgVuwX9pf2<2%{TLm@e-ch?x;Gg2CjBRAGs29Ne&nNXy~<{K;&XMy7+!K&v? 
z@0huJeQpt2ZBKu3+$dN01}s;sBXnFgSh(1~vFY)Jtiw$1sx?HF_x5vFv%VH2t6sHp zF{<3BQjg=th-(5m79D#0!Bl(OAw{I@?|XR^wN%_BYDO5`R3j-+4>ZRi}(d zuI^#q_Dp7C14QT_!Jn&JoE3Vm0LsqTwKrx&mSR5w7YRI&J%(gVC?}<9YVI;ad~u~~ zM!Lr5IF!s3wZWhr2@{6QipT1IeZjDmR3x>iqUO@);1FN*yL|bb&V6g# zo8Y;Pz5HO`{v`J!wrXtljkQItRHI)r`)9txuhQzyNgzsb=(2o>A&!bHS z+ufg)$IEE@pc6hRP%ty%tc!b5U-R4^H@=~ByiGe+zgBg|ge7V3)?0F< ziTd(QGQ6*HD1K^_#z~A}0l_=~eYjM?%%1au<5K1A17kKyev&Hv86-#Ia|c)VlZ>|n1pWO}vSbxqggW;;WHRZOwZf)Xv>KF`{m@wsJF>Vu zY+8y`w%_p#-4ePgvtsY=c7_Ok@u5e^jrvKdQyUVncU+iYKThd1hV5zL-r1ycF2wJWPL#Zn(SM)u{?Oo&V+@bEwov$VTLUnmt zIGZFs59fhYReArOWvTj}>-odZ)J6bfU}ohqOl%$B&DXjE_RnbIS~EAHbqqc~%ND=x zp_l`p+8G0-c>uZJx`*-sQ)_UywzfH@@-#Dk^RrARe7L2w-Vq))h40wC$K#}gBeq~k&%2r61drZ&%e_;?#`Y`1$J6G=t@TKdb0Z9r*sISShlxqoDqMO^H4(?jZ( zR>WaaA05d8&OCC^J7dwd&%4Dm-onIahe{vb{4|>(o)t>&O9Te{C7b9z&GjF$b8t%D^5{%&uUwnL5biL=t%I@0-lp!yM z@EXNozOiv^s!-5a`1Sh#(O6<3=f(34k#*`fn|~^f^X#~fNAQ3!y})d}Tn@z@y)2#l z9Ttc&^Ht`;#Q5_l#H~W2tE)R&D7(5q=z*W>3zX+gZM?YFeHBe6qkaf?vmE~=ZmeOf zOShq|#^O`c5IuADb+zy&-ZUlxx%Vq1o%)oT?lvt|cekS=uAmBxF|}&_WcA9gRC5t^ zUn#yfEvI&-l92`ARJ29J6=wPC?Uq_lhYW}!S1-%u>^2X&wxStdNyP?cw5ql5~op&I|hP}nV z!r!=^_1hG7!zDT5%mqH(7r%$u3k>ZyvbS82v|-qBGJu!IpX^Qqknq#4Sft$qs`Y4w zK(>I4hlEvdaGu@T+6Ki8>y=)%oe{vD){!4!c>o)cYX?E;84VV-H-H})wAe>dwmzhJ z_41PEDce=x8MR^7wRpRQpt|(2vqODr_GWp~n3nktn2LUAO;%StrZ^`NrwbdhU_QOH zqOp`?N?`;YctOavG@1-d zTwKsEan2W{$1wp&>N0aSc*CSyx?X`K%oKF~1Nz37@t`6zsG0hr$t#Y<4)?QADu%HN z9Y=!C(e~XsnR&A>o*4JI*jg+t%LgVMc|71h*!AKu($DZ763X$A&Wa873g%oH7f$%QQr} zz}vHXIEct){Q*PXFU9{xkk&jCJYm|cq*4GgRGUw-Fy-NDOCXckKSn7akZuC)^(tPx zZCdO!M?4M(UlD!j3+3$v>ntH}s}feBw(nrGAsgdf@Te09cl=%>A6)9_$O8#~wH0uQ zb^|zXxh|ZQuM{b;gb>@NW5YC?R6T1-r^N{_#j%>Bc;{6<7tKk3gpDq6fobuGkljTM ze(xo26Iat*#fLB#fKhI!Y}egZ($QNE9Dsi}tR8Zly_s(lTL`?#qS${ha_rz~smGj@KBok^&e`m(aY`oenJ%G(`t;=TR?BY>|n_SEsMvlOWj ztwY1+WOZV_0WV4!q+^JDf|Tipe=;=SUmmDmxr`iJ=Xn!cS8LWw6jPml_Ii5VkK3RE?=2_sr#+gEy z|D5{eS{1TMb7h65@Ar<0TUp6lOL$YX%yWQL>Q%{nG->&|9PGrN8hsd3)9M+D7BrF3 zmRO&W@ynFttN%oi4)Waij0K3C^w-B3kIW9M0hY#ERAmoM!-pF6g`Q$s|C%Jw;IdIAV< z$Xp$|eUOb8EghcOxmOTK?8YB#9I{*P7(ho^lUd0J51c{9V^ls4pxl6t_^!q;XYR73 z5V>L3H3NY*94nR6gnamHn|ATsy^_yPzsui0Biv~Oc7Qv5-=wD%ApIyc ztMR)VChu@yzbZPr2KZuRk*&>ozsPmG#PDc$M^89qp2$tv(~SAabZP6+Bv$PLk9vns zzvz$7V*^PF%bUeqK99+K?Dcgt98@Qdeyb>^AeO7$f1%x*TI)u>^`ce50kgzm*zZeU zCU{-H*SS7|z~lo@dD@~lyWTa?F-cR6arR+vkF&(b1Ib~>-!6=6AZzSmdx$qf{29|- z03?v|m=I>Ct(7awI6=~9437PhyN-lsdQ}szHL?zG+ zr)A!Yh@7ViNwyyuQpGvR>eJ~sju;ZKM=)gjs;#4t8dyw^^T zu_3$(ockdySNRQAyfDu2m5?xe6Vm%s$?)=vBRH$avu$au7CE2D>W;{p3@JLBHADU< z4xK8!pj+p|X-URSZs!A~ooNikG%lNq6}t{Cn|M->6z;SeglMvyr1Y8eN4-mtiMJ8@ zHY?d(c^I3MJ&Hxd2U7yC%Y~jk&5eybk$UaHVP8qRXCFeloDs>cWE5}^5i%ZG4{GbF z6YC81rgjd?ue+)W@ARgU^i`VqWuDn@$BZRP7`M608_KU%p%89jS*|Rv?W8Qndn?>5 zHm`dc)2C?TXkJ^apGIgHNVzgtOk71>SOU=H3tKeR`44J$K2Ulmw=Z%ZBKS2%~qo}a)skof|ee4-);Ptz`h z1T2;*VQxVugI?&;C_{2T!dS4uBGIW2dh$`E*bIqq9@8aAoqn7$d2g7O&0H4s{>vg- zZ?Oz)Pv-Nfd9(PbDI8JS*X?v{Ni#0_|^WbNSqJ+_dAF>ihcJQ4EbUeQZQuzQ?zvuE|Wx?r<83CT)!(79j4kR9A^ z&B<#ihKv#4cH1d8JCUbQ^ZQw-T`B(sk$Wq`kFY4SZZ=rXp^21Z=O6=<-u!pYCrf}7 z5OEa$jl%BK5vtW|#?54jFE6cQHcwx>_lY(lSf2pRdzDu$Zr9h>2kixPU3`_Qh|TiE z`X1@Ju=hjs&(rs3FP2)o00#~4ZC&08z(|VJ1cl@1Cr^WK)RLaG@B$9_tzjwc9-71k zPrH}lsaJ4gyIJxhS94-*ido7&Daj5kB2Ew2<{SQew2ZrMHQM8foSA@fP)r@7K=_8D zOIF$g7us8b&FNf9E#y}7_&jXo3?tJww4MU5j6deW^kqITc~V)QVB zVHo5eAdzx(+7p~=F?G*}hld&#kH8-P9th;Y*;*D+@8(GN#A;Q>&YoCcKreyS-o%T9 z4B=c`e1eU;YfWu5UkW>3a)<=4uo_S$%3 zNvSWn3-qUr&6V0)^=hpg){--Dr7nB)SEw4nsw3*CheoL?$$nLa8t=QQ+lt5+z*Dj1wE*st&2`gf&w9|Cty$K!u@sUw6eOY*t zK@tGSrd&+|OCQaD>FLUl5P2VAJ--Dp#x;iTLKJwDV 
zm;Ega1WC;HgGaUAvO`aDxaM8@AF;pPogE3#N!Mt59?xt8V1NhC&$+)RCnAGdIm%q{ zLQ_S(|0fvX|BYe|u=0n*zg7gKy~F%ZQ10K+n*rZ`$5_7kveWu^xRyVQJsJxH_~Nw3 zf%^+t{ZHV_pG659HuM|zuiI-G85zw=L+DU+(yz;PR``L$WMnaOr5e;UG;ltTH=+YE zq$odi$jW%D;DI*_+_lUnCnpREm4c4^pH+cm}|*Ff}lE( zgN1{p%V%7b8kMP^qWHBawuRH`WCjs8C(eupY1v%nTYy(uO}5CiS&q_ z=dM1f)*BsdE1?Q5P=ix0Z74XiGVy-H6>KMBtIuSM z+JrR*lV)L4QZiUSh%4eN)XYiZuI49b?|_2MmOE8?Nm|mO|M6?**)@2wZBX*5nVYh= zqG_JN$|8cWYSHxz;k1*=WXU)hBFahjuua4RaY~bgHfav@sWO1}HE%yd^EjWbtdRJ9 z3*+rRajgKvH30#+WJVPF`)B;71qogmH24pT%nl!69sk^^aW+^eYF`b=%zA9E_h`Qj zpVRi6v{iwlii|xgJWGw`*RE|T>G*{r93xjeL@?Kkb8;$wzHsqCP$uN^`u&pD!OS~U z(qk?qjpyA-Ta%p~bH>K^Z(gJMSQT%Zmu$yUKgFddA_)QNe5qYt?t8rAWp96oCwo$w zlPMu!T~=wjZec##?&kQ-;F)6)n2f6ww?S~P?EFzw2A)bBah-Z88;lBBBM$s|>iOft z_RO!5{Z}|)hlE_Y*@J8aK$tu4;a3PJ2*WDmar3YRtF3@?#%+ z3UK%_75A69twWMrdS>e;1`K+;oCc8VdNSL~pBfI!hx-a?KKE4FNRP-$Aw)VJ$J*xD z?S3rRjcm?AysU^lStYR+ZX?FuVy0whe4l7hzFYheg0>N2TsJ>nd{~vZ@0{pqV8g7d z_U!d!-NdfXTrOX)&?9yekrweOZ~KSl?;~t9DqU??$>d?g>0q%Ier|vV$o)1h2gnEd zBJr{C@YJO*3~J4>`eVssMT6ieDJY<@u&~tOz_floB>mphLgc(UNZ`E|2cb)W{t3cD z^cK#OP^eg_GQM`?8TF4pyZp^LuB`HrUZ$MA2Q$DKeY>vr%(b`HAVU1Agszn{aIVXL6M!y0@wN>!jVSVdzBNb^ z$5$;#Q!OOn*vD)xXfW@E&$JWcjXU}zlMC-IiN2m^GM_E+Y|yJ4XMyV{DCaT&0@5Mt zgYOfEH_MG32by`=<&ko9yb~a1jg$=~GwfLGzmz5DCAYGA#$nkqP z@GbLakC=h;e5a9c*8pG%h~WHogGP-`lWId zM>3_yrF(s9GDXC(uV9K!Fpl?)))NeJ@sJuT7(EekMCCV5zYwK6x;TmkX}es4RxQG1 zTZ6f7bv8=tgf|0(j9Kdk?8{%iez*xuMBtq#U-=rFG@ff$;?D8rq#he|8#*U4aRDJO z!B6xMArsMKM&I6Fh5%TAkEQceY)XHrB)aItPnayXpyialTN<{z-0j7fN{6qaBl;!R zi$#qz#sX=%BiR1I8oA&?gZ1XO=p=Aig4bg_?*4cb3Vks`lhbhx>~ zGR5k2@o9M`ZuNTdayMKL7F#eK6VyQi3E0LEO_`k5x74(_0{0&UVu{6&YyH!xxr`@( zYno4se4Or+OpSs%lpSacgRj}>UA`*^&*_K%>S!jZ?Xl_K8F_-$S zZ!)Lk{sBUjVB5JEj|~oW0x+00VD&t1awd;wh^v63dvy4qlpgqYBTe6%0kJ@Ebspp^ zV{I#mUQ!jGc(~gMDO(A~dp-`OxG*HAWn(}{wCqJgQhgZ5{c19R=ql}Kq-@RAg#!uQ z`}C<h*!vzkulZ+EDh9Xo*)~Xt}Nae9Cflc^u-m6TQBvP@A8FtVJ#@Jt95sYTOrh0U zWI${qIc&$%Vt*U%B+hShFEBIPzL=Tmgy@gso%;D2eqv%2<10VkFGO%0y;&{ExR7w> z#X8g+7ozQ*pZ0Trq*DWN-qQ=<^LfdAdRPtdo()hQ=|#ZMkk|a;sJmf`CmT8)HD=Ml zGuE?XJpH2n3+Kk($#AmTX1p)XWRr3i)aLkwkBOlDtB#M4ACS5j2TL}dZk-(5?sNVL zE8(#Cjhz2*+^=!pE7XY4Uuidhpm&1&u)+d-IJcm{gXh$sO>%|K&^#R*st5L1h+`%t zA%a(NR2^0i#K>p;=zig(N7w~0As>*g$g70lN;k&~=XAX*4R%tNmUR6wq>Zt~reIa4d?r2UA4|7L!{Jyt~!Ba*3x9cP#Athz2@XkDpsd{YPgz zzk+?+Td6Ikcp7<|ihb0;rM6|(P2CfMVKY^Q@#!gLWrih0vQ_)CO z@$-@KwY}BIQ|(9%7X(p4-^Wp$ex3g2l^3`A!2C35Ehz3=qJQGRUQzqW*^X%)Gx^u$ zm`Q*1)V&s;HT_;mlC5d^4w|FD#D_u%7`na2|MF?(Egf`mZ?0$e(a`~fC0hooUfv5{ z9XLUwx_NTzfcq{c!6619<~U7%>5=9j*mmKIz;tk93V z2Wmy(tbg~-%sqSBGV@Xmb`Kw*z*LdK>6uR`AoXG`pbV;;IcsWq42x@jT}1mNe(hYT z5T+F*bDR~}_lf*o0|Vc!^nAf5K5R1AmR(y)iDHu>Y~A-okUfS0wlO54T=dJA zw?v$l`Fk%2aBxP$-`~@iE{r#z*+wd*W{n2&hnY$x$9L{2IE8pwc$7)HGu`{-muEGf zaH0Y|Set347}hqE z|E%X|5`i6RXX9(7m8Mhd&H4Y)^%hWZ1zERn0tA8wCpf{9;O-J6Sa5fDcXth%;4Z;} zHUziE-Q8*2-KFXNJ2U^AZ@%|tSc}D~MR9LY_jKK=z0ck!=yv54wB??$%yM(QGE@k} zdE!1c_=IN0xEixJ&z<-9L?1V%CczcXT$~!t&o|%NtcVu?OTSWpc%$YRC!X@@4ILG4 z=qDxh17Dq4UbC3ndH)=~>#`>@tVTn<_lv=tPf)|M+pLK=ARDkQhGtF7^l=o?&3*OL zQqW}d`W@AzUt`bn5<$dsv~=K}*)>l*pYRmI_{yV{_*Y#ILzauC6zvPy_&A|d@$jA_ zEoEHs{?4xHkj7%9x6o!njmDgZgGuPPICr_Mz90mkI>&sa0)xShN3);)W)up)+~BNx z?UPnhbxwCjx>HGyK%<>GSS{`+2U~LfxLm52Z>3nFgOhrSt z{3GDZG*$!4i^V78Mxl!p<^Pf{CA&e}-lg)I%m3FPM1*q=| zAGl;huJDka#cbEB&%er^hxHNac+?EtjQEMU(T+ZnY{KUb_Bc$&bUf?xz!KaQ~kQj{43`%>MPY`bcG`@wVqE4QIWfbqBv~ z8RIVMF?#veuXy4&;kVw+0IDI%y%_#8B^8r~T+{)24a6Z^w`eVsr~7raqnD?ZaT>y9 zD@DpNfr;JR#eIE-^>#Q=MP>*tb}yscT{G!Pq5a`{az+DawSf*vzE>gMFV4I0*4KXg zc=%)~Yd#H;WSPsdl*>Uw;Yrx&vy|MSe!NA7G zM!2z!n)7VEd^89KmE-Gyr#d{3Om 
z5jEL^l*o2p^uOG^Rz@XA_+~zwHJ-EK1+>kRq=02X{Fbj!gQZQN!ggjkOjjFVB$m8K z3W4eEpUT=6D?3O1!xL-(L)QFb1T+V3I2(nlHdu~`U&eY^VGPeG)xM7apd(;WP(yDK zM3v@?;ac)|2mK-}1Hx)n9x-8zWv@ZHPkr_)*fTyC?2xA)$T>l3%*tzc72jg2eG-3= zbgj+D|F{i2**E8(k!?4e6Hzx3SI{(3VcxnmyU&v>pd?&^NM{zz+(>)`OmheL)yq=!4vnLR#9mo{xKQIi_Wx4lCK+it#^i zccakbG|9;XFmi=}m$)pE48|MsqEBzy=1N|}W9%nlJ09QX8uK74FyE3D?$)cp#$Xp# zS_mVlZoOLow{3n?(dyM{vjxnwt}R#rQ%CO`&_b5)C=;KkePze*<-dlRI7W2O`?d|w zwF^pqoN>+B2bKA|a8&AP4qeh|p;hB_zP7UmmRX&2y;Ldzem5`jE1b_w`*5AdZR#PG z!?NRE`L6oCOUA7%x}JUQ-|^&=_Vg!3@p~jJYT!iEqmPhnT%8$mULHADSFgT$rwlVY zBqM4oXu?}X$@*x=8go1rZvg8qAMU0z95`QEaKSE?TDi8)j@?1`zN7gz3ss*X+wsDm zgP`RAplApp`a2Smis{N16PIU*5qy#)1Tuy+2{xQj{*q$0h#aP+GU2q&1NPp~s1GG3 z&yD2$65MH~&Bq{Kgm)ib@-C~;m$-+!j1NeaczCu0j`xFCm4wAmObPXR6=%iCK8frL z=Kt{m*o1iD%(mzHUYin0$taExsw2q*zq+k2wC~?<<*Kx)<|xVOE%~{p&RML_lTsQy&8XO<&>MeS?*IOH4GfsFsOuJuof;2gMw3m;zQLB$viLS(R z#ANNr#00!txt|d~Vu%eqac?v1nv?|UhG&E7(a)sGdfTH$^z&W3yM|*eVopxRN%l>2 z5ZevqhQ4(M3`RN`1 z#F2Pnd~`bJnrY3+F2)!z}cx-(%KmDa;tjd4YW;Q+4pU9c`?6 zfcgsJg`)nds&0?wdRCiVDl<-Mc{`R015}@m1){60#c*gBr)P}H=A(PIV^PE}xqz(@Zv}qyc#J@PkOI|MilAP_CWMKKGg`&Av z8U^El%t&Z1)$s_&+~yJHNi95n-0|w#2?`xnWo(_F#g8R6WMrz*Jj0SMGE-%Pr^;`x z;#EeB8p>D505_S+G!G2)yHozH^t3PVnehd;JJ64gM;?8yX*032zzHrPNf%F(xvK!9 zplp`QaBfw#`RPShXit57VnGe{oPz^0sRTa1N4fKA{Atw9uKCWQgZQ)b185leCj6huvi{=w8PWh2)^E=wxv$f6R1TTbvo)~1$ zE*R+3kA~w1NOP^cY?Nf%`2MjB*oIz?hKjZG4}on^7ege!dnPa_aoQqs-WRmn~06dDXypS$@l8CMh}uWvw z;cn;YAw0w$`gR_^Eb?7jT}6EpFS^D zv1wM>UfTy=|8=W{zx`8C|0utO<3(acdAY_+hjAOP%UTP_w^d|BpP{dui9l}?g>sK2 zTeYY}UpH1(?^#T~{F(S_L5kS)o|Z2OE8>0??jx2SjN zzWiF4CX&FK_qEtHUsRnjxbZF`O%07$(Nfly_#Y-nwwp>*&ZF?d{Txw=S@VRLH$tiH zBSFXHAK2B06!UUFo~8*xO3W>$T$_SvIGWk;9W`4Ix3|7_aK~xC?A2*buRi~f46x8< z8zqU-AMpj2-AMp6mTL@?pRN7zA0L^{*+*Zxl5;%%GBaIF4@PRXBk45SP7qqNYhRaY z_biEV@bsVQOxISzMx^aNI1bnUqIv^riGQasr0$`>U#|b(SnEVHkRqL*s z{SUgy9-Y{rhfS|KL$mSrFH2@2N#{&*uARTKOl2DRdEgWqM(*Hd1Y|M1fjZ-+{-tpg zW6^Nw9yGRhca{IPFc70~z#~e!e_);zBAjfBSsMnisj0FsiIxa2VKpuob>&BD0ro$b zqyI?SJXv&h(n9tO-*e|C{!7dGzbO@%VGI9zLtcIOF-(duitnt_J|3XGByn32_P0cPv|ZDMMjb zV2)sAC8@IO?^z!bE83^c`LZ|E@gp}Iu4JRi(SBATTS1ru>9P5Z-_obw>?AfYYYTN0 zHU+5MZB~$pp-mG9sqdWIXJ=6ug%QDei_I3t6l}A|!s9sycw1LKe1K&mfAhsOTdUIl zJe&BY*rJPTqLHa4B_E*cHZ$zF)q6KKs}nRizGig3q4-<87a}+)CH+#Z{zI~K_@N!M70jhT#$Y+US@Uf|l`_^3#BznoHs8L`^l!@S(!*wFkiQaHy~ zk&t)2Il|E{k{D&6uofdmrB)>zMg_AE-o3+~(+Sv4*85x>XNzxg{OkXJ{S$0yj^*`1 z+>a*{C2LC>GCF_XeS$~-SFDfXb&WEurlAR9J`A=HSrax>(O0Yqt z-9S*3f4bICPkp8{344lW-i7Xmw{K(j5!o`?SuBd_yz&Na<_SwOTuE7POqT5E?A%?Y*UDlxjEhDlK-ksQwedQr!EWV4q7czMz#yI*?UPBN02~ z+wlI03PDvpkuhk&#0X90t>6fRobVsx!QBN_3O#QPap&BCDQ+l)4eD|4$ zpI2sw8q581rXq!feNiniPa`};D8Rv}{={1aOomGqJ&wPxitKG8at78gqNP zrlgY z^!zcIDS2)oC8KMt)eZ+_NQX*La{L;nK;6|k$AeiVW6ywdP-8Is2 z1&V}mmx5~={^O~>PfS_5Ywr~NOhLm%)DPSt7K~75LEvAJ(<8aBcg*s>N@42t7F~y?Zs3tHs7 z-kfjZ%mQ1dq8@XTPpCa=?Qo+frj|c+@Vh7~zKbIhEJN_)T$UDnIhre&g~JCuhAovMRE?!sLqAZM>_M67FmT4r#)W$ZSmFhJf*>pAmi5F#%c?4ImS)rQ;@(!Td$XTrH7as=U4X`yMZCBz`<^ZT33A^+oO__UekiZ z`$n*6sN-NK_+c!dt8scyv&KH|V@bc%hCjSq`TxZjW&#B(dfaM5*t19mfq8Sw< zz*G65@@x13EBsf!;gx7?1>IOr(~XWdcO_p!X(P2b30}VpOg8?p3Y=8^V+@562Uyg$ zI^i2r6I);f2Q{=G$;#f{)}vBli53=Jym?Gl&VL-Ypl3y2ibyXk=etmX|Ii&G+vuw! 
zW(Yy?&`lp_8vkGofSQo3Hjy^h3+2UBH2Ra+>xZO$t47sa3 zzj?_0JvYilKDujE&{CuA^L2RUENZ9Hz5-+X40iE`tLxEuCV`LYq+RuMggd&cujdzC z6pQZ$25;KN{FZl5rJu(|g(I>)vekU3^~2v94U&6wVLFf4j|HSDl(x6iX(8>Jf0hdX zvStd&_*UHQ+D;?#Bi1R23TteLXF$MNu4zaVC~Epa@Gy{b?>SyeIs*d46H)fYg7H%x zcb`4=Pdt4E&xc5EPo0C=~-#E#|<)_K)G^UF?V|UY?v4w&3>Qe3ggCgze<)rUo-m*egYP zydm>&E(O2Jeq@8iryUWO7Ye$s>UhO$Oyz<8KiI0_c?K9+6&V*J!iA1-A(B!VjG2CA zk;NNaF|x0kY-^p{u}la?u_oF%5%!6O!?s~(8fZ{-a4kzqZW(8 zIuAM-cboX0M%aEEANLVvqx(IT+X_sw+ z4DC+Jctj#wESQoF-?L&aU!&p>DL^@dmuW0q+s40n8m5aX1S&Hr(wJ7sU7xiN&gRxb zNaW8JWz2`GEGJt;wtuFi<#F710AGmaCm!)TfY*Fn<4~pD#uW7qHJ!QxFZMeIC@L!S z94~9NSuL3ry3%gTs_x4xW+%{O_+FDZN7f~Z*zLg9t3}zV-MA9>H#W*K-tn!Y!DDld z@Tmkt_N9=oHB!D(t31tW(N(Uy&8$p%8r_Lz6$;&{enFCM4r7Z|a`UWR@F;4D6g)I@ zdiXla^{i9icJk+5(L(#P*R(z~_Cg`&!w5IsEE$_p@~!0*9#Am+$Gks=JdhjfyK19Z zq!urdMMrj7gDa8mAkE0@kk-_czOK1K+wHEmItZ{2+8Z((w*x$S9o}jOB!GK+1!sq@ zm3I$OJ1oRKZQ$f|ys*ZD&<7PHf_AN%1NDY{^-v%TqZ3)YIOM(;uVFjra1H%W74`i7 zU>-7`6V|rNe&P(yKsV=8#fUmc>itXc=mWEDZb;sEfhBbmr(l@a)#J-s0|*AuI+$nG zCX)f(6Sj_>0t&7ZWu+5Aa03B;U(&BLWxeO^WQ!qwvwJ6MncoUXhS>5tzklZ?&jmaw zwriG?#n;o%qRnsRVA)jG$!Xu3>(BN^P53KqmscqmO49YqdO zCJ!&fILMeClAeX8>=Av#D_bP~P`BgniLDJEuf1~s6U)Z@(O%`}#*BZuEf}?9OR# zFh~@}kBpt1HY3V-R}DhKg@XGhMybi+Vh_QR)*GFx>h62{Q@^5@(e&$8{T@xHR5;QY z(W7#h&uen3QyNcq#{iNQ0v*sR5U3lPGUZr`qpUAQvj781G`T@$Xtjb{Cy)iG8=zUjtwvqc~a!vHZgCAVn7oZnb2B zA|}(&Uy)~|2H>Vz%u0@W&_c;;gxAFBUsWuW8B$K{-;s9#;r=k&&WyKKONgu>TeJUg z5%%E@*gcrYqD)nPyU`f}*MQ>PWlb+Ie>8&rO5gNV8R8SMK_7RwJES@&&@nyVEL?14hCx3m&V5e}V}6EW9S@y9^V?Rm4So zi*LJ#`jO&uWU!uM%|c{*qvbDYjyubk?=)jq3*vIDQT>{~KHkBh7@zm+D~j+*;mtzZ zcs+Q3@V|@O{_);_VEk zi_w(FXN~(3$!ZQ@dh+ z730>q{x}wGsbW437)D~$IDmxOjODQ{>EI0v47}8Y6L?X6F@ld~OG98WOZ6exE*cAm z%Bkqi($W^O*84O)11jIZ29j?uNI{_)K-eDcH)3JXBb(Dv5W!vvXa1yp8HVd=O-0|B zY4<`9bbw@Wany#Ln^dKdWPYPF7n~ky<8psvQRfCZhBtbjoJ+K=Qc5DPm}an zw>YbTjXC*-?!ed+)V*}8tBB6ysYGs=%(+hmzoqSRY?rs0)8Yjtgo|Hr6T@lSp+<2K z?GKPG?JDVusM3oL-{FSOQ32y)D23kaWYKC4Bku1!ho3L`Sm7?eWp)PIz1B>Udu#dm z07bp~nvXRhojcUovX{|l-gFYL0-%pT>1+{jm!%fyRQt)ZkK?Ol$y+%@y{S76z%CC5 zcJv@iU`U{a3Fh8gPr2`9;VW3VBae>iT;jM(hZ zDF5^+p(lqeHZWE9H(kBI^NC~-?HgqSac7lR%0mjKI_de5Wb6bVhrYqPQ$B660vQdv zU&;k?jd)Q@mp$TWjya9)!1=4x6r4ig&*&`1IS2hLj&aKi^iCkrvF)$LE|Ga3pNEDTx6+P{|j4gYjM91qrESjIg+8 zXg2QSB)%W+XSTWB3!R9^VS(&K?Sw5&gsJ#(r=XnBCl}#Qz>UFXgL2-~JO@=2k9K?( zB#X<6cZ=%&f@vn5L8?KzIg;y;f|v2~I-)TCRtt8{O(`=N=4CQ3veD?(} zFGLlk8e@SP1AK6j`eWjGZ8p~;iU6hq;M!?q%c+Eb>h6M0KsvYdAl3ebGoF5BhaTn6 zcLe1Xi(%)U}F* zA)iRvRt29K6_cu-^m!nIR5vX#QaLB)Gpc_Es}W9{c(uTD`uK}KEP?>BTU@E>yQ ze?h8ANRhO*MS!jvum+jXE4LvW1thPRO#6h)f2rU&yj^|fMzDka?D;f4s`dGf@P2(V z3H5*#P{*MZN6c;CIXo=Gdm0iqW8Hz}@=iPlTuV&-1#-65>Wp>OZ4k;!{u6F#eTIuh zdC-$dZ^F~J=dP$Q+3xBj2k86slswXI9^7ryz)6uj^>t98V!5+qHM}$9euBmCw(yTd zPS*X!DSfg-AmUbGvsl*%v)xvl!09ob95KLQ(Q{6P2r#@q+rceu%L%Y##bs4}5zN;O zUoNQm1`9nT3ca=W%f)TVdB|@@ICb*lUvuuFk4h>P+CC_(KUqPT*>XzQeAYF-wRdO^B{-8Ri>nbmzLGMBU7m%- zZ|!g;glIn#@;n`be z`eT|J+Siq|sdJVbh&x+QW7j$WusH0Fwf9@0`7Fjx_V$8p z)QZB20N1QLzgV>Y&h<@s@x`YFs_J4t?oS{KJHZ|feJeUNA{_~rCZC|hBL0pdRgHW+ za5DMPNk2OdxK>Af2M6sfV%>;{d2olj4~6#M)}Ozv0QZq7-{99>f$r}IVG06!OmQ1n zOT_)XB0d;3J2t>w8v79e{%;3E5f%!iwTLJvocF*pzeE0?nfdqYGSyFYmQ!#b|7XUp@no?@vZ>}zEQN)IhF3d7)z#IV z;K%Ek-OH=1?l@BZizSQPe3XNu7$I_%f@V5YBx64Cg5 zKPbgE&d$c4#c!2ga4oltV3t-a92}tl05jxa8?(2s@3IYMM!)BE^nvJp8IqmhkX94?bI6`=vC``W8*-uTxkFQxATW$OR#XAYa4-c;%G#A4ILP7Rn~ zpOCJ6zU8=A((BGxA{-Lzt|L-b*3UmrtN*ROgiB8lkd)GCtmBc}a~|+)G?dzIxbU0k z!-wtaiwHj+m(ABPQN)I0jUkTP?krreBe6!6h6R5N@mBD;I%y}<`N-q~bMO$=va#WW z`AK_blu*IK2O2e2;B8(+KtS%`1cCo~m`RC!fotj3gYhqV-4*Ska_&aDmf*QbMe%rHU&eaNZknNM0iCX%p 
zG>y{;`Lkx4$y4{`VHscBKJ^{c1$VoHFx-2-zS^gj8>7YspqVZtwl-C2L7ZbScPj^d zMxw)fx&#gRMKl+f+|NsSr3Y(mcWh43Gmp;ir$G1!5W^99T(006pg@94X6(uQR@A3I zTpaIRk?CVKCJa0_j~+8M8;o*jjjmc;-_UQu&OCHX4|-pux3Mn>K}d0xr<_q5snu0b zBX*`I12JS1D$kfx^b+?mj_`U|)Wf(w$#^osYQJaq^ikPVrp@(spSS4fVa?6FpH+Vc zuC;oM9KNT5sU3OTPY9uGzgleV?7UJm0Q#G)=Rh${O~M=@zu`CcA0(nSjsYEra&Nz7 z^ChF!(#dCa_Qw)mUQY;z9JK^Mg_v`*vSPLe;;)B-r;B8-uMiOtJ;4vGwBOPKVIPNl z-^VN8>3mp%h$YNUzrV@~th=HqO?Wm_g1O%2)l<*RfgvRqw%qT%KTaz2lO5(feX>-G z^!KE@RBK_zh|8%r=9bVryHU|tDkjpv@!w~R~D;dC|AMEC-w8fQ)p zyQ!SinO(wT&g8rYj7rjc(@_ZIQqDO4YG;3G9^d(-6ki@vEVr>SKCC(?UWhZ+q(S zE-biSn(WujkjViGAqB~;ziXQQc>>tWHoRM(>uD(pc`RHCbn5uf%ngMXTtQ+i-`0SovYV3+-F*RE+)_L(I@xd%EC&r@> z=Gdd3lS~q^KzYt{QI=B;rTsZZu*u(cbB$Mc^h8IGne?gEml(=h*7e#{zQGIUPMGf3PJGU=jR7zGmhhFKcaa~_Cw_4bAeC&J2m{8lYh z))z}UUcgsZGx?GT7lI}u(Ocw)>3I%O!3R(L8G401I+Brbjm>$i9U_FAkVeswS*hM( zybOnXbBjJzVTMGO3Xup!ta*RH_S<1C39x5@yFtwDB>Enw&*OD6ubVA1(MqHE^Z1Jw zUgfe|&->gEB0O*raF<(MkSlILL1%Pc{;}^MoF5gq@b_t^o++)Q+#2dvCAp$J_s=~>Fr}pv znY^T^ytN=gR75MtEu!5Ztewc2-{LrZyQND%ve8bSaTR zkbw&>%3N7rd(+zJIT8#lrLt%?w6|7Iv1mRBz#AteH9Vh9%bgka4Lp|+z(g6jX#oO- zVq{Z8-4dUCIfV;efeIx|(vs+@IO9#83H1Rj71pVeZ5X-NyvN~CFkCY?nb4JIJgCBROaI3WbGgO^HqONR= zEwFI@^)opjLyKZ&)it*`eWKt=Y6G@Qh`$`;202p7Hm677(Hp}e&sSn?dO891yznen z?Qc@j-RTO_6{g50paaq^NnWHOS166ByilU1r(Msi4znudjEs+$oX8P$f>FsrHYKYm zxw+)>IRZm+2nSEK788oK#al*s*RVy|=@cQ#hTamD9vO0Xm!Y=pWe`^vrq*mwqL7Oy z7;!UYK1n-2+7H82+>f~40wwRNm$o(6Pk^Ky`7VI`!xoXUhZU&qck+&$C`eCn@a&Vc zr(9%w-{|-D=3(K=8x=>_!|zDf_FC*{R5GBqd9*I#Fs>!Tx}69e$QvN1Im&;A1#!5E`#9W|ft4=c`T zk^rXtFc(rzyoEOvrd3z_8K)24Pw|9YrNT3|lUdk9p`eBm# zbvpKsBiQ3iWPZgatcbB=Iyp0wxFEYQgdoB@`R!wAw+*``YB?8cskQ8i=HXKQz6-1n zly9GeTF2E^3s;RfjFyB;EaS6Ta6pT;=5zvgVpOFBJnu==xt(S*@yns~_KW{v)}hix zKvy5;^8(zxJOjR6j`=D;e*MHKyV$`QWZt@s><6a}#-pGy~ zBFD4$g6TlDYv$bYEL&a(0JX_40E^|Xh|7x|O-smNZ;E2F9O90Fv{Hkj)`_LAA8P{t z2&05GNQZ(DY;93&H1`76Z=>0>*MkG2G1HGaLUbi`ec23}Nrdv=q)Z&6Vy_N;)kDF| z4K=wB8cE}8A?+|#s$r-i%O22pitug23ZZjI`g66NyE5Sj8xJ%~+9P{e|KUXle4z#v zXa&GfE7A)AzA3zluwCHS0a`~nybNl7&n-lD1K3((^>jB~W?P?u&r|)(qcPIq^~j|Q z8v||&rsM8Zl^jw*YY)q6`(t3df!e}YdJCjDqb|;@I-%)l5i?8 zetXe8_@ZWsVv?_NOPA^QXP29k(@eev*8nk=h_kyZ1nG-DWN6Rpe7$`YqglV*`(VD3 z7@tA2^SYd>Jpc;PU4w4+M64c?@mi^16w1iuDhR*e%FD~^O602Em68{}nXJ;|MQxoJfY zR{Cl5vD_3TZxvIn_Zp|1_6aKFw!=ODJ5rJ1!Is4vjW9aJZVgm&467*rk9XWTIyLc| zub9?h29T@1Y^3~$-x@ASY)nk4A0m8q$7F`ahg({gH_mz#d0$FyC8zpJnTY1=e*n^x z+4s1f*Z8=?WzTmX{J#kAVT#niJ?#%;9jd;RDoR?lV-O_gr-bsP*@BCEwQ&Ted)qNbept-h1>)jxQD} zR<$8>oje9T_DGv&xP}w1x$u?E|3*_=DAm=PVM`ORq1JY@jQjp#VSJP|)y7aX_U!M? zrxu_Lk7)rxD17%G|4m|}IpQ>@9NFL5a562a0{6>g5a-5f2EL~9$+O!}Svqu%yM|2> z#+t53Zp{o*fzEaQ1@8{#kT1 z$cCIJKpmP$2dNOfKU@1E3f&vaY<>mE`ukHLw(6&zaOG-Zr$4MJ+d95p{P*uhcU_39 z<^0(e_`l;zVh~GUbji^E_ce@`!7}->0X<_^r+`k z{A1>Gh<|rXwNihoDr0RNXm5l&f!Q4evH+t)6oDoe(utgpBzED#n8xeC#AZV>@D0`i zT;6Cr`6^uX9drM}yz~gtdIl>`43eGS=Z+@11DDOPj9}XvJ z7F;TYvo{AFjxA{#2z@baSe9JBuRK!Us{iTL`E_FEo8dXdg8cZ)N9M4>15ZLCTnVe9 z?+CoxD^j>4GftiHw#Hou3+A@joE7Oq)uz6^|wwjkNQ&U0-0=08N+h%3a#@w#oY@*%O4u`(fXSupIg zTA8E6_wD{zc;p#VNd+q!nWT-CL23GnvbS`Y8+C6ecvq+SSM)UL{%! 
z+4V@?&5ql)HZbZmIHJF5kvF>B>TA>6wF4)uP$#~^zJYDkN3vhL!|ELV-MrwdPxn*o zhwq^Mq?NH8Kv+f(Qtz$Ds|!oc`T{!*rv||4683KX92$y=!J0E`960noeFh zh{ztK=e#3%=4vVZ??aoR!A@xpvO-K|cqG9yQq6t9=spFLD#@t9cUWvstS zePnK>dcx0`dAE;O=E&~J!iMFCKeoRm25_mXJ>vVtt}ySE-^t6_7!B*xLxdZj3$U;6 z6hvEbbHc<;CVOvfPd<^*ytMef^*~8WC{y&KH6R-{I^QgB^5>5XP`48pPm~hdNUX6+ zL>_*0!Ct_F?Qci8pJLA4MOsOcs@UJ43ZFy17{gX1&I2!w!HT}X()Iy1D8Y_FSADFS zJegXu+0MBUm+xR)uDWTi5JpONl2Qwq%U6%}^UwI-^!ojKh^jqGu~I5B$FlFU`MkMi zj*Td>(uB?46Fih&GIzYBDG>drWk;D_+@w1(h;A55%L5_!WlK6)1lp};O)ksTW}vqO zC->W=(vtZ$!ee9bYu@&xS&ijj-*T+5Zgpi3z>4IM>ys?ZHWvs>!OD>D zM+1cPx^lgi6aMHxWYKOYO(geZkbNVy92`*+k77rP^X5QhBoy@Md9v;nT2l{ zLU8FQ#XJiE55QU1D>@20#AyVh(On;!FYVHyFB%GXoWqSS18q6V0@n+Y%$Y%Oy zsL^odTxe*Lfr{j=kQHiq@e>4i`?eyAywB?rt(*w;+$a9eutAADX-?k=v*x(WbKXpp zKv>KK3-XZ!NWNQ1`w;d)KmY%WoWl=ZE?RubZgz==Wipj!uZCNO{>v@@cwAA!>`57+%0U%Hfu7-xh?#yHV^M666O9&I z?@D;s`$k{;l7 zNyf7DrwM=dCpim)xX(;WSp=!y*H5BC&c*c2gQ3q!1YU1oX!t4XkMs>d`@DoK3>oiy zW`VV!Lc0?Be=%zxG*j9K8-UwX0t}wM5JLNi3bB$MJy!`Cr2j8yGz@GGG8D;G1sCO! zq*i?|_;0W_0{olA*9q6RG_8um!ThwOv;|JL`9`9G0#PK6H<|E&)S1H9-3b9JM< ztL}dTmdkqf zZN2rJD6F8Ud^TU$KsC2^zJ_A5*0O>>S%2T`}^=kQpfZ5xf9Cd|*I1UCZ9RBqAV=hp{J)Mvro1a)lM1 z=~=+#04o^b~Bm9H99OP-Z|};JU6`C7Qa8sT#vXV6 zaaXH99ZXjfMs!ukyrDtE4Q3rd;Yv&(n?4yYa05 zkF>Lns-sJ@esG8279=4^a3{FC1$TFMcb5bS1eXiJg1b8ecXxMpm-|)T>FMd2o}M-9 zn?LJT0cUYf)j9S2_TG=N`hvRc;vi1#p@u08x~Gh}NjDaijP6Ul-CA_Ea9G#K2+Hp6 zE`U~k0}UIEt0e9}1RU)QC20$k>$Y;q$3uIizKs=ndAxb(u5~@pq<@y40|5nXO050m zv9XwX;fwWk{Y(ZR)DI|V3+xOgpc4=vy1BX0kE{Lq5};Qg`R0&tX05_NI60ZB4GU=Q z7Xu(5!{6cPS?oaIpFI#843xvMbHrovDK2TZ;Z)Vs)cD?4>32eoPU-o*z`((U@pMzE zq?4!@%6g)vvfB(+uD1J;0K1cT%7EN9i^pU=b)exe0?>&%I4o=n==1=z;DiJHp2TVc zY{2J5Ezf~F)Xm&#omT9H%x(uH_U1E^ZLd6^-{w40QH-&lvhL*u!MZb-x2%N^ILT0PJ`C>C-JZN@WW%*Db2@!;Xg;81@-NZjV~GfT&_RyhrRYH{8HPMMs9-n&b9sV{I;OJhpDXmIPA~>L5zfya4h}UHzAhA)Gp?J;(aENo*NS_Mg>~41VVvw zVy>t%?amftsnk}X0W4(HvkZRE!-`Q{PalwBM3(vbn7-*`^zT^!H&))v#>9as=d=B7 z5Yb*w2G^`;f7p(Th4R?jxdH7gj&gNWKn!NEbfM?2HC^7MTKlC}naN?s$Dc-9n@FXO z3`!C6lZ5(`v#a7jJwdpmqF|S1a{zg0gAcM6`vvAL4`j?4jNwYO$>qmbpqQL1U|^#T z)!i%?U}WVbQYBeiyk}?d(!Q3%mqt5;mrlF#K#uWn!$RS^zgMHk=ddaA#~%qWnnA2n zpik(7C6DY|W6D_nO5n-t9N1n^K#4nVx7sqGMw{!INUPDMR@&OynkVhOlWE%vyyQK| zP?g#y(P}8lfq@;`A3Np4Gw^G)zx<4V(tN3_t?e01pm{JB0M6`qd!SFSt9Dfwbi>l> zG+_Ztjh`6~Cj+CDizT6zw>M8qON#)eEbD=Y5R>d1$wd-Hpyc$+rX(UF;9khv zb^`fJ$ICU?F>@%!!o1ed51PC#juCWy2M#m5X{QeB6c9qz1@Ob}Bhc2B zJgf>&2}BqU?cMa`qT5Wdc^y1rE_-2xj#~{a1Pgp>vy`QI#zM_dZUIUDA}{69T%k89 z8Sl!uwlljEYgV;L<#V`@Z18j@w$wXmXs{nof z^SSu#X&y}~Qd97snV$;ZymN(3J1>IUx*F&on`O;N>_3Tg(roeJoSOK6(URs#c7h}M_|HAD~*XM^uW zthpKxRf^OM7k_-(UoiX(x{j^kuks&=M?##<{r%CMSO@6}apzwB^pp|nTHB1`o@PsK zieIhm5u#7>C3E^fG6ej1bEJ)}hkpumLcfoE`YHPL2Sahx6g5ww*_0(8OsaO4WDpRV znMbpCb~ZKx=6~3Zcb=2HkBE-$UGoEvv^Pj4Q1`SqOqwWiar(XT;KGt8xKUD4iinAA zK3wfdHC|E_*^~qaQm9BUPP>k?tEL`*uaQQcM`RAcAeSMSP=1jsi$H zsRD3<9V!?zB&hC5*7BWed}I(1dW9XEN@lF7)fVi?3nSea$*G4b1K%MVpfhZ^I2ZI z<0^|$_uKUCoh^5;ZcuY>vFQoF^u)J93Nz%SFp<5LFM`soPWxhTL1T^yw5P!;xO5{e1FoL_2?p zEJD9+kL_s>46JQXaG<}&Oz*n`ZHmggE1Ie3ha~#45A9b~Dla9qDFnMPuU4dtP;Ew| z3_CCj&>a3=z)(w8JZm;UIyBzkU=-`ajBtu&&IStptXY`=In!#tAr#9ln8Iw5JM=6e zL|1q^md1gm)9kwAIyQ(uTcQzUF_nLTX7M(b|Bd7g=rVy$E77cAAgBd(u{U+jeR>82 zQn%Ov1#gUIa2Fr_?47)XIsy=VqlIw2?TWsg-I%O8E3Aa9s7s2C&rRx$gZ&sz=Fr_f z@vR|M1O=+iWWr*r3z6Cqi04P4^%AQ%QhuWRhKB@c{Sf7pf4LcQcdRY@z0(OsrkIe8 ze4S>y=K`{79o-eW{Wra&SDU9TaoRNg?|u@^{uC%uNh-fyl`<7guqb#PDpT=?nu1j} zNz0>~B4_r_ZB4cCgH%tNERdVZ3{==uSs@xyO3AAQ>(;wCXYM$Z4j6iZZRXfj9dBYF z$zp$`4^D&4vir{2ei>hJn))@T6~yCrjG|yx?1%YOSo0x$T=wH`#+4^~+O;s^*6;d8 zXTEIvzS^seeopdx5C2J6UoOhCQ;G#HRaoh;v43#B@9FJek`wlvN=YMR-p`6MLptQU 
zn=rlTj?k2b`5L`o)e);4Gczr0w}Tf3j*7yssW(L|5YMCx-p(Kf2|Ooz$un_jl@SGW zCMv7AUFo0SBtK=cv2#T~tSe-Q#lu61-~FyyC~h1CB3C5pj2+xm?kJ{MU zQhrDIg05IpI%CB-?3KINp~tgnBtRO|HlTiLT!XbK4a6h?gh}xP!A`rwDGlBa%&H~o z-&dC6j=VFPP=IE_@_f3juFtUb#KNl@M5KZ?fHbK=3hXz>upU#pcJ8XjLMXPSXmWAn z>Q@4$d^5FI+I1^S(>TR7$mlr;S+d{_iKAj^$X6}<5fkFmc(AZAdxQW9#^u+l%JjH6 zT5_kpAZJ7UUyUjvI-RMo%x8v8p*IIxUh37}6tFQ5kO{>zlT`LPT}76FO8Wd)jHbcX z?9{Svvg&Ccs;j{ko9tcsair`}Vh3+~xdV9E&{~rdW37y*L**{#=N6OI0KM&$jOz=B zEu6t(;ECg$D(!yQ5#?@<+{oqrW+jGrij%;n*eMo_Z~IE4A7zcU*tI^9PX&g6Se7<4 zQ?=;wTwPF#e~aqAb3w5(Xe?*dsI|NhUuNAM4uP}wrlJ^Nb%|+4*@rG~N4vT^iso+} zH;P^2K2uY%IXJ@IpO#WTS0DC&xmC#!dng#jj~WS^%|g_=NoV)w8_j5=Rq3WA zvrc6rA!?0+`{Bwb2h|9IO)7LMERL zcbQCH#5snSu-aii71eBDG6C;6MwU$xmqpy1ePrB}ir}?YFAhLL_nQ*-qut>)Ky0q7 z5p#J34sW+V#;Wc}hAX1h1ijIr+{kc!lO<66G@Xg{YXk^F4>~R#Flxu@kLOLH7!PB` zg72<2Lu4}<6dr}Qi2Eb=QT)!sYl&%<2{|CP1I(M<53Cm=(UidJKPZaWlKv|o=Bbw9 zR*yY{O?{iQ0#k#%*mNU^tWUeRaM44LqV|)F-g^JAK|IKl;AI@jZj7 zS|EA&uGCaCl!8ab^eLo#S1B;Vg{{TZC~g|;_b`-X1$F%5gpnU}7CureZm`q46>bbj zF-gasxo2bGJf10={XSjmHzUa4&8GSlsj8aq@}NWpE@goxk`Uxs_+8_xERTOltFysM ziHiQ)3P+7YVQZ$r#53F-~5f2_#wow4w@$XGJ+wiA-kHggos%s^S zd8~nyhXeyEgGfLt(Am`m$?bLuU0SEx>>9Q7fa$f(#Lum@^S#Lj;4%&C^5Phy2~Y^v zz&E{D{J8x)rpvDVrFR6nKaKMNKFfEd(J5@{aB8r#Eaeuc zQVHwa)|dI+*ok8D_;W;_hGxB5-fQn>s_S_`U3t2SvTS^=P^+roRL?17OURfj61CLi8-*b|?eOm<;1$rGQJNf<=|4SW!L=VMYaT7kNlCHf; zMFh&te;*x15CBy9(Y8BcPhZ9z^q5_`B{axHl-yllikm8bek#egu5ES0<)$@0d~~L@ zOZ4?-ttX?5EI;47Wr@6dG}Ti!3gqHs%z2B25!EM8_Hb)5XUi+hB<%CUxNfw)nuw#{ z?@dDMtKQqpN^cj-UzyH11CWl+;)@Vp&29xmjNi>sk7LOdiK|s8>V@Mv8cJ;J45>Y* zgUt6qe@V7xz|THyUN9B|u?)y-{>=n)nxAy)uofFQH|jz4mpPv>cUeNGe<*vHfHtqQ z1t=iUZ23)>9(1!zExTA5bTh8 zQR}Javc{i6gjBt}5ur+7el z<6AY7nEog^H=tEfCCg_+A`$Q?A+DIf%j zfV^d&veGFTHXi~K>>C;y&gs^u=Zs$Uf)uhVhg5by~!v>4KiQM^lv5^^i|t1_`sGZU}DmlmmMSufLpNIF1a1V zpz6R~I*MLL_S;+`9g}e09`RdK6 zGY8up(JqWezTjo2J#C5*)6};29zRVXPf^3#faXcrclj-DTcqBU1H0R$T37$uqeN8G zd4IikU)r<3)(O}9rKM+I(8(+q`Uhlg>Ozj$#- zOM->5*DRT&B4FT;09{Q(4!7csn;lm&86z75Qn{0#r1>-n(p%NXljoJ;1eyJqU8992 zNgtUEZZTfI#G*cB<%d-Q4%=n_ganmD7``9K+a93lCvri47VVlrT7-iiatt(xZfC1I z;*Uv8@~V4^o!=hrzoBStdN}p+*F8}rbJ?qvN{u`Ufq83E4@T0z5S#C@=UB;Bk{Q+iZX56wGcstsRqBnJs zD}D*>`POj7-xybqzKG&K{UqtGg6ceG5i*xtx=6EhXl5MD0LF&O_oJ!Hnk4BxQ;}6**R@w=Va{!)170S9K4vcvCL?DYe=xS@kI@ zlF;9kwO?|)8n5r`1Mg+r*Qr|z*iZ{2s$!=`eF}8VY&j{JU78pCc$K_~W{%L4H~QL( z$xqO+gC@se&Kc1~hl-=|8azfBfNxDT%vn~4L^%)mKTWSb5&t!KotL8XFZe1JkTg%F zsRq`coMeh$dLjuODN_gM61L=#85VAE*$Kgc9z!DK|LWRUA6SW9+@jZoJ?U+%f8~Y< zFcN@&g*i=(Kt=j*Z4UnWq)*83o%~PQ5$XSONcO-L?Hw!369)hNnBO8BgCJgx&0g#s z+uSc8n_W+YBqZSbCj^t|b+QTy`XDGMDYw=;gBk$I_wmWe#^&ZvocsU4MY?6e)PGzN z6%~xSe&0Wct^<14B0WR^3JC)n&j$-YwSh{)_pPDf^SBEkA))Wfc(yR$!ibH`Qpgee zrBaMjzBgT{^hY>eQd&CE;^D+8uPb_qM6;3YRR_3Ao%hEfOifKaU=+$Y+A9Ew`R4XO zy!*`|wq~7m1kk2u$Y-!v!q*_rY_6{luCTE1ofV)qze4<}eTKglqz{FII$C!2ZLS#i z#eq*w6=^*XE5uG3o6Ogo+5z)Q>a?d~b3dq}%pbMs;gi?00rikO5s?W0@$ z-d6OY=RZRtTth?5%DZ!H9yBiqv#>>g2??KSn!D4wk~Dmq$Cb2qC?P92S@c72K#)4K zB9xvkS3(DrYu!)tF1M{W1VmaK0zq5}+f9y`69CEzYgqc<0HU@{w_n~tj(8Wm-l4kX z_NV*jqwcMMd=}k6!Tju$Hf9b`J-)lR3_%*Qv$XwC|H#Y<1Hp%Te(V5$$X0~H4QIs} z^`U4`G*rBMZyf(NTGC&~iH5bNv~MVbXT8r?oN~a0)m$4cv9|a_vFkQu%3qKuE3#f@ zJ+E1LyaE;`v;%w(^LlEtI*6L+Qk37FbDEoF;PXfYZzjiFVol3oWAat=hIXcN%i?dC zh*tMSGMpm22@p=3PyB@8;pJuWyf*(+K?aO4p~b~C#C+cHK(!7q*g)y|++g!P9kak_ z*ipg(W}~681HFI5MhNmCK2ZSNJQM;VS{Yec58YH5z#X$b2Y{pgr?>xrjUa$e*5ls* zRH{X$I~2vh@aMha@mf1Sm+NsjI+>XNVnucJ7A7_jz_#-__x5TtnMD1-cRaX15C%zr z=jCQj59o5=kK}%EIE+>@O~DH%o5sGQZgFV={PD|=Uz$)Y&jD%*NGmZ%M@Qk&(E37< zaPQr7Bxz+s{*TNE1yYWv>$LP}y9zzK6W&zlX><$i;sZNDd|1sEqz~s+u=1+!e}V#5IVQb 
zGm5ueUZh?wa=^amN%qu88|)N$jF@`{ijtVJT4Qm(aU2N>#EBT zS(J=cN6lkY&6l$ft+{A75ae!edhK#C5F89LjWYFYexe!z+!pz1T2P&Zj@MuxL4@;%>_vXV3rq0gg zLx+VKF;@u|I655c_`hZDLc~G}OSQ4Ju1CLiaUXV#M5fOg zO;^$%FL9f&!GE(FX;*ycOoaya>26L!06Gl?>_|?Xpt5*>$v|)c$Tt!wxSMV(@KHh- z9nF>=K>MYBAq{u|=#Y1h&|nMTyZAw)CdaiSuGv_SR_SqN%7-acyhJXY_~Xj^^n5%p69DCW?Y?X zUUMYBH0B7u|7a zciKj*3g5?2Jv{vX21iF4`YRM`&Vk@XexHUfz6YM9gm(G__MejPcW`JybGY$hR20j3 zb>L;+&~^@QH^0$dW%M7POoDg7+`}+nVM^>XGYm_~p)l2YX?)!H$jzmN)`_wNp{X$X z-X6-0!wHXHCKjV@IZ6sMz7(!Z8|oTC0`%!Y;q(T1r{(Hycz(~UkVJ5={hx;+NZCZ4 z5v&ZF$`KG4fJTuteCT3lb%U-f_P_ULy^vmuY;8uGW-7g^-MBjaaJkE^R50Cac{&Jh z#XGO&bdsN5SY?QK!U~7;c(rk#%a|c{DO@ubJ8It$EY5WIlGN+3{~~_Q_*Jy50s8v9 z((Ji}!Ix;^(i~j*eR%|dB`l{og_Q7H+S?6}I9pQzLk0;yN0JAjJ+mkNsCl!jBC zD!&w4F{bim1jlf?AO4yRs6Qg(0|VJUF(G7%&vms3alUfbHg|s+cgg@nN659~3-|G< zfD5%p5b|@O%2XuK3-jbNWm{XDwn=8NFJi8$T+t)oyV%+*dY$-#lHlNIRBE}kEpR76=_ zb%43N#abv7m|YhH6wZv%5MHo8md5^X5Ry7mndHptxURhO0h^3%yZZ%(_I?49{DP&; z@>e`{<>F5T(&&nN6ai&kV5aA|6c6&cLz4dmI+|9Miv}po z><@8_Sw(GlU1JFrZHb-AWKnrNRFMI25AG5proUA^mJvwxf$ZmIrSDh6%T^f+vf8bg z24jG~DDlF=M)thfinG9kapvU!uqp;>D(!of^A`Ajq(*nwf2T%~Yb~h_V=J((f#3)v zcVevGVltP<)+N;Jxpn%3BE6?EMOLGt0sQWr>bZqOYaNZ7b`5DQ5%Jy+v>yY_R8&S= ze&Tz%wRL2chTS*MH8@Mw8gUl8GT+ivUj+sxJeEb)58Io#mi{|83Kj%e_$V+^e`kmI z&0u^{4t#E_3=l8qy^EIucXZ4!OE%gn0k~g@^wqSQj4aAPaRex?{RSS3`W6-ufVpfP z=nA}}FF~3VY|%toIvoF{l%Cag8P6Cb6UYrywL|$ndOA5zfDc+?#Sf4>|FI#@p^{!9 zB7U5XJA7^NxEcq%YPNT@2{~J-Y*vgmi>$BjQ6MkjNDE(;JQu1=4mv2&+&`j$EeNfM zg$Ak)pMOiyfIlL;oi1|@`C5Ju_SIi$a3JHb`J?#1D-fihU4;J!ezcl?Ho<$`^94ip z-D0j)=JbQ4jo;;${3@HNcXY{`zP%N9XE(=R>xcS&e_4tzoh{d84@$5Z@`ipV zt!BJL!f&|n9Puo?2m_EyDjn7AM|phd1cl`5-%Rlkvd}<1AA#raa9pK?FU*Z)X%Mf* za!yv~ZjDBEuENx;qy<+XZ+OI-E(N6u26yz}a!n?$9SgG#xLXDqiDAS@KXE1$g2#kn zHmtr3>^cp;dY4rnr{xHjATUj>1-D$0QdX{9 z8s#EsUoH4~C~Ve2Lo&?*oYpHo(MYl{%O?p8d(>sc>P|3Hq6RuWns}#2XzlL;P*@x) zE!yErQ|H9SM$Zt}scp3qlL+UhMMpemQr`H+?(sH_;ePza*}ytbMQais87`<1&#H%E zA|rL-2q>*Bho(MubxiP9THjihUHXN}@==BHtJ-!y7Y^Nrv7$LHMmuJ}z!W=zoCp}g z)zQu_1Zgotmt170s)kWt1A#cL$Ccq41Ln}G7+ED6)2_)6(%~U;=P{9yT>$u5-vzqR z_etJUWT!%wR#9`Wc^$p_TfrOm82>XiDrA^II*pnB0ePe`%K{HkGxzrP3RW zSeCSKG$6x>`Aj1c)@SbCb{G3`2UG^>%)84fp;0s&PF0cHkDXa9^H<}gsYd^a)B7nTUf9OgD_@nF#5P}rvY&K$G%c|Zp*>isN|QIk(Mr4DdUu7b zzJ!!|$`SIB?8)gRR!V|iArKvN)Fjqt0`(6BmX!isjhFtoG8Vwi6;~MfK&sr`V$FkT<8)-Jtr?(Ds|zY z?u&y85A&r)@=!Tty03*>IsZ(F-ndk_YXS~K1Q#I2828+Quo7sV6QDXM$p+%-pW?L1 zPKt}`QX?IB1r0+~A?Dhy6)I820*q)tMu6H4F#ESTL1s*`2j&j0`4WPy4GV_Y@&C%HW!DdV@4FR=uCzL$Kw zVs3^|G%;Z7h)77cCgzvZ+G*YVA1TqitIwVs$;Jzr%QK-G8 zDeh}U49QSh9tp|FNq@v9P5!7~C^zS$OEp6glVAqV1K`40tcqdXGDqqF&64yqyhr56 z`~!!~t@xN3_ueWTJR7g!3sjjQDaRT_PkVk+pCN3Wcz<@q3riYnbki-=a$*j;-XLie{wH|#a-D{dDeF7LM^ z@kqv%b07f5K^9_%#{+!~tt7a+rzgbjLCgW;+rV|ueX|-(S$`TNcsp06#Lu4i>C0oW zTA}*Gidnd+5WfvrZEca*D4+gEv)k>pLE8vP=_(DJ;)-6IK872t#8#b8((C@X3+b4m zNO|~jopR{AryCM>?60QCdp4T0mi`OU>{0`bT)>I}$wk9SCCIM*VP&OHfv9TQ) z50XML!BNmoQbP@s3_HzUcV#~kkS;?*Ll?LT(nKg=fNY^%OcclWIN7Y4044G=(eg>9 z`ckIXh3u1`-rFiJAV$JuIW^KOgwCj`qv*j*9lKzmweM68F}~?9T{eJ)%#-W*zW)Oj zq7D<;9cEysNBHA!#k^7ELg3%9-f5#{!KBJJ7EiCF`{buf#8Q_u)7XI+{#4{n;B<)$ z#E~gDK#fPkaw2%*Y`HVFi|aZ%Eb@(XFpl8x|BVbS$V4jk7bga^*PluW>hd!l=3&QyF*O`Gi{3wyQe` zpNc4{oB%M0C&ITGN3&B8oAa;?(0vvUF-!vfnSf^_0Qc~NJ;^lxgicNWtr|%U1gK5Y}KyOhx)3fXygS~`E*uUaCq<9j2 z{}I3?-xvHD{a;X?4!Hg|1J6P;V5pW#$UpX_|6=y}x5JM>{uTq%`{;j#dG6l;mLM>k zyAsUDe}#D@01%{SDD*G!)j0%n78t^C5JOQ$2Kn}M07`N~0s?Pf{Ct*uK@}&^6!Co%8SDpni}f` zj9RO?(Uu^BmQ+0;kKpFwLJ(+s+y`{>(Lwn+*iyvO2GW>+ARQ>71jdqSV%DO&wLLK8=&X_n+^dj?EsrzqE?+vbsnSf+xOjeBRoM3ckm! 
zsIm(1L^*NiY>$8X{^KpfoI*hWU6JOj;NXz4HvH#+ShJWf&ALYF)y%2zP2malJ$*f0 zcm3`5_6|W2g^d(hLb#Qo(7c^6;2dw-B{nyMJ53g#DUP8EhpMAf(A~S1uC?+N9h+7B za>{30vX87dFU?C^om=+zdhW4hg&&T6o)yYXy=f^a3)a^o(i*^@k#D1lJA0T;bc6qH z8~R-~C+O5n(2}Y1#PFW}y+&DeN)-9wTw!>aeAkB0G1t?oXH;}4VU}Qqlw;bv&N%nh~u^~pE0H#9?c6ddSl-DdkY zh|-pO_HGOI+04e=-f#rsH7YL4vn%@9!+HR9!1dlvL!t zBFBV|WwRU;#lLCZRW&!)iYfSaLi9kuS_U<=sFGbq#918)n1rMll`l#me|k4&8K~TS z>(5dKsL?DC8f3rwkr3_gOh|=Py1Aq;Se0D{TMKwXn`>T%;eL&b49v(NqNSz%Qf(^h z?&)a=ROj{qK?$7n#XlK~+*3c!14219ihSRf$7dvSAk9ttJ_+MbHWnqyemb;I(a_Qo zA;{+zAF!kATUsJQK|#@~SG>v2&bD`R+k^B8LBs~y8zJx^2$=kO1)|=%ev4e>&A0pc zIUmmhQ0Uv6T0wPGVxI6qr2@n2{Rzb@AYuotq^ztg9Vw!q{q+$E6&2MZP`+{UxkJik z(rLWLLg`gz`R`t$A5R5<{oLSWp_&rp4&>zhfkG8VeW-j9d7rFZ4yQL%JOC+yimO{$ z;GDIh$>}8j`gl>i_lt);so0Hz%p2Q9FBr)PgkpeOYGD-O*i+s}JEv za`N@9_a_F*lXy3!XoE^u<^Z<4Dn<%()G?T&1U)<`g7F?8pY0F&|;ZN zJ-Sp9)6PkgD+#oC(rEJCfIu)bzhIEq$hw<136Gdx=qfShPTLOaPb>9|MduK(q}{V! zPf~gTP%6**o}nbSkw*hi6*?=9?XmELQmCg&!oZF`H((9c^}wLg*;~63nX8!nWs~Pv z(t^h7Mj7(vg~B{Td-%VSq5tG2ayY#tdp;jhXk4CJr&{*I3Jr#%S@5c((&m4RKsZLX zx75VsP8I;qn5~>I1k1tu5(o93;4*KtexO#P{purLwlR7j`b)hH`Iim%Q4u2x+wM*D zt$9cN9C_lM8i(_&vFkZo5`dJlN@`r~07;T+!rb%)kw)=t~Ap=4Goq(D4*N>>^2 zBOD7%3Qm72T;OcTLE9p0NhC!^Td0;64WEPS<1#PwF*r^(TqWsD-u1(uHIx0%n(4Zy zxw3G!F{j?-ORwo7@sxq{?*5%my7+~+*O;vk;4}E_F-}6W`---a{awA)1>yJ|;l=e_ zx!zX%`>VIZ(9dGGDE7-sXgLKQ9Y)r6v9T<_PmJc7v-yrapo!Ozem4;fHTcWlAUk9X z!+2gr?yS2VzpEnxht!L{6j`YtnaE&cs;$5^boIu9Y>TOzL$QjbwPZmwL}&tTcpG5hjavfA--8K*63X~J{V>A-VgS>ErX(u z+_K@9Du;xGWQT^~p2wCxZta=(Dw+75yubGQmqL!7p`l@%Xh^Shl$`;Ls!b1hfbL%y@)Oh^2eLAcJQd$2Owoa-`F_Vac6KN z&8}_Z8^7cPrF`aaIUqgLDU}pN4h8n}^t~V)tZ;*N*As#y^)E|TyCdC{An!mR%Vf_o z?c*IQ-qBq79!>VgmAa~`xFI0r5sFT;Cbm3=h`Se`T$7aapdN^GOr+E5#f5@batMv< zl~VsB(+AA8M!-G%p!!{pyd)O}1JFYN-oy{QycsM#xiB9r{92^cl@t{#yzXqrg>YN< zC7S~>ttNFpoqn0=CTy;I$^3e;R_j8<({i}ET1w+hzsSvHw9`xa_Db%KqpF^PJ{@86 z8-D_V$>hTgH+@ceD_OiIH}e75)^}-7o@Dn1mkt~@W)f$srB00pN6qLtX?@RXOD^1E zPw{q+KBF_$PtNI}-t8oGzQPsZ5jnHp540T^@@*eeHL${EOG@U+`O~fpQ84qde3Brj zBA;<2ngskAeopGNzh9znzYS&KXP;aKGDJn|j}`W^Qwzq${B6zkyXdRKUoo>YqwivW74(g{EUGCpe@oVoz>lq>svpqXGqs&46 zTInFPYG3p3GeuNqcF;jRRUy6&b#ZAr&Isf zb82*}{+_khMB1wd+xMY7BZ|N94%3KS_88bsSrBmdl$`{=Nl7j988nuShHs(s@ z9!W&`!-ec@#DNb?54oI-&ih{)2Z8z5=;;T9^*9{O(#2Y~V}(CNL`C)d@!X*$0p4uz zM%?99CcvkLQUKDlGB$=bF)^|1{N7GOM=yz0waHIY`F(FiuCSz0UJtP+`=?lTco=DE zXXBh8R1lewVg&SA_A(IJYmp%VDL_>`2>wIqH>K; z9Fm7^_IOYh6@~4z(DlKLNsWczh8+03sqrUqh->3kXdQr^Vn8BSh`Dv`rdv!h1P@{X zD6xaLvpSv0TxGmtXk0FYvOX&&Y&X%RR_{M>q00*yHU8kIra&m>pu$-jTG0HJQ9i4> zn|@FbDjDw|SgSHB1R+n=#Sh#BjDd9{>qYCL{<&2}dk?xtM^6ozUK1?iQ=)>ps|(Ous0*jS^%e}P?^b!$lxV=Z-za3*IPDtj9}(l%y#|WF zvYL+{*tv)SWlcp*)w&G-3=_LTKOG${_Hb+bS28a*3jJoV?(1o7U9BQ_(#%c#U8b@Q zFvptwRIXJ+=K4N0}>mt*uak+fa z-Ny%hAVf!vA<=ptDA#vPyDM=*O%}+LG&D6C-dcxnWy@4Tez-ZBgIl1)+>#4WevF79 z14Agf16%o|%}Jqd?z0vs!&$(`*;O4;NIvt@=(P6}hy_?a)hH6*!=_Qy2O0-HlcGRs zvc@WdKdp_9M$36$_=`tg=t&yUiBS=x(aV;6e4sRK(5km3I8(xfk%Bq`lw5#o2{3Sv z|56~c?fQvCo4R;zSa+ksYiT%ki_&5~9&jQdqu4h;97d0=7<=d`TJhIQBL8*V{k3i=bFOea&@RX8>6}e2lJV>BzDFc>Kr?D zl@aXH4s)vMV5Sont0srt^hn#+vv<~b)hl^68IV@=I!>|DPfL#;SI_7D&>A`*jKL)u zR{Nhu@t>8rqasZn)2*#_wJ34;f ze2lD>$Wm6GRVyj)Zt%Uk^1Wxo#08-uO%z$UT%RO@Ak>(w3F56D+&jCQsf{}B&a2>D zp&de7H3(3M_(6#R1b8&OL zqrC(EsO9|y^(fJ%tEEuyDu@-W=0vMOLGxe?w$;l>M{fdulag+UH$!l>v(^Tjk zV;pe%9x-AjzB{p3ao6MS1p#ey>9xc4d>=xJSK~+cCFeiI3bR$Uwma>iCJDV>^pV_O z4lxLhzF`Z2fP!5IR_a}wSfa-)L74;(T)cg|sY&M!6frLxnk7Pj3)Mgyt&Y({QDu>YW6`kWFyOe|r zSMV}>#Kn7B6%VO;cj?MEbBqTs z(+))2i4x49-s#Ic(n5kX7M3@N(z|FK3Efv;y1VgL!sxmKvW^rLPOrb=k#P+!?U`FI z8^=u!fUG+L@rPaKmE0y357pNy27gn9hPN5`==3Jw6#G7V!?&d@kx-^64QqU@J8*XP zLOJ~AKk{_{Lh{d6 
zZjAXvU$Bkv@$PI4w?_)muDQ(aeRScpo%RzQ;7^)%l4vU<%2@d_SI_x~On^4Xk6n|L zwD58UTE+w~m-~sgXU8)dy&^(CwX--8}pU$|bbC(?X|$+i-B3eu5y5k%)M}82I0_06-ODq5M6|Xkl3<&6#+Nj*+7Z zN0_QAPNHgAVeK=>=n4_)=QeKw%Xzwz-}}O_zYIN++RpJ`kECaKrW4tmGm02c2qc)p zy}7(U(-geOEMUDHZa2B%)>iUAe$5*XJE35-m?nJwrLB;if_cdbs}n{q>etINl0KeP1CrL!?x6EczBO$OH1`3~)=)w;8O z#j@Pd=*Wbf2%4|5y;ESm?{*}?hfseNDCp8^6d-NiOoXr_G;n4i0dOg4;s*6zLyS~s z`4q5rd7a?-#jVEU`WQA1aIT{9Kl?{uQaTDv)cNA$<1Zuou#mTd0@B==@=rFWj&VQesFT zNl>o>gW@L=CsS^0W@bnFx%na1X3s6N;5%lA*zhNkh2wY>eom;8 zSxJLg-2JGu-d)XEHq0?g6#9VJq60|7Xw3I_OxX-if>rGjTq2oFw8IK#qBw~FM3PrW zn~!P)fKjRNVd{?#9yKa64Ukm3l(QiD?Naot{)do0N zsyn<$70{w_49P43VTul~ zOK^9G1W#}e?gS^e6ChY{*Wm6F+}+*X-Mt&@SL}VxIQNcs_jzyBFS>!Qu2H(?n%}CM zo{hq+s~2In`=y%sjGT7dd0UIt0;*BY88`Lp`Yc?;o)*k3u7vn!{i7zKxYpOs@{mip zwTf2RtL2v17mSV@A>-LRkoZ5x;E%9$?<5Tfll5~m)c z+wRKMf~`M;mY%gfWyRa)AXmS5l@QQ4*AWO_uQUOL>f z(mf3lle{do6wF9fBf4n5%R10VP!D+>(a6c6Ka|2$fpdyaY1Y=z;97Pi2o z?Ehmc(~cvWGllt}8H-&z=FeIz%kB0flMMil#R--MhiA(?%GNNFhu&}0m?#36>#5n3 zrcNUI?9VjT@6Y#FxA_N+ITf^9vnpDcwasAYV&BU~kARs8<&^KI2YZ7d0^W_E8D%F; zl%|NlKN~IXW6AQwFKd~&F3CSl9&N`uFdnFlQtVPiOP!u1!o9MBQ294|N*(k#jxMD) zTKI67Vs+O~Hr?%{>qp;eGLPUbPGm4=hNgT{x&`^YYT&i1G&M+O3uM;{K?Dl zA?(XNQa>YeF!0c93IEDM^GHh&>T)&EZm08FULT3JQ zx7b2WPep>b$N73luE3FL=|t03jnShDhPaa3$I!MsUUFS|VfhRP zYI*c6U0!E$H*73)tZ`8vUp#LaN72@d9~r_2@)@MPZ4RfR*R<56)EU9hgQ4Em#C9?q zVlQw1E#T;nhDTio zy3h8d*{SNWj*QONj#3Kn`}4{Fc7zd{slz=hiSrDVLf@Se_Yy1d%LjU+?(Qr1hwxU zW0wT*|G-;zW_>+k1n(Z@UiD#MVu#NXyIt{s*|!1tpTE189J4iSsePq7Z&_B+fM1qm zB_whG&Gdy8>e(X7cJutcWGeEq|8@?{wAuKlQ9dstD#e(XZ95Kga8{K&icwDf#&TiT z8~4jsGc8BijK$d@84>CTxD$-uAZ_1f^XQI=QyrkN?}Ipw4=7UkzW*58fsA|QZ zuH}KqX(i+*sO_Q9F`&vw_zpk9s)>nIkPO=@c}{o=ZpLMG_IHP*v~0`}$HjMv5dy|~ z2=&`0R2}^`&4vAIafa!l4cY{n&!QxiZ{MjnxX-+tqTTkqxSdEE2h%p_9W0fQ94%Do z2Wx5%~kD z(7M*O%Xtadi|O|K`CWd+NQ{z&thjHfww_-*C0u1u7-H1ly$8bUvOiAo7MvW7T!kMC zzI3?n*wq}){cyh-l%I>xQWxtUuXu|L6k4DZWI5S^zQVwbZnE3ZinIN^QiAptabx_D zyr4{d6)qBGeTJ??RZbZnwUJ^7P{lhxZ93Zo?h&~?K+LNyTPR6zMVF_`4PE%Bv!$9> zZa}rZ4~QW$%F{4720Gms>Qn$Q^=w5w|U=Zlue@N!>?JpacCC zkUG}^91ko&zYUNp6ch_L3uHjpY7D0fk!);ifPk?s6F}qW=8uHpk;(SYo&a$I5^KV5 zfN^3T8=o~`fNjs)OwCXN95S#YH9%!JmCqSfv(|DWM7_~hHCr5IRk{5_zF`2UUz4S@ zAPZ_YqmF1e#!sdD3-->}VTG7_KRRO+(hRtGJ5aHkuZFYT%`yG%tt)oUa+4vsy$tQa zidS2$gw|Rv+U##F1#MsHuFuuNXHleYv|ya@i0G~kz?z1i^^K*3ceOUN71+p?yHzmk zwLhFpx>G0D9RxIxAd9Q-Q%o1l!Stb>v{!tGh_va)?6}ieimp|>^^_%A@K2BT=sfuI z4U98wv96}>(r|b{sCsu;!(uE7UrRk5>UkWnanZE|a+~qKB$qEs@g6|(?9wQ8BaXA* zG0wQTR$?9>OD~HqPEt<$Ib33C39({zb5x}dP}=qDiz=A4!T0*@itT8C_6LOEZ11ikvZ21+iyXX8Bnz$M#1A&ar8Y)v1^a8+A=q6d}lR=2m* z5NTsm3w7GruQ6_8UvFe6oNyNL#zMtcuFwD5N(paTK+negi9FHLb5yfB>6$PZ`F%i* z?p*-wf$AHVhN6AKm)mSlb2agF(wwV~ajaFVThXdFI8gu)~Qbio z45;#a6phV&_-<4Qm)u_G$ehIYNq1s*J1TvcJ49GWUM=vBA5$1dRBm6Q>sa2GST%tQ zb6%qB;>q+&b%uy38fAT$G(o{GIR1go?+j<+%Xh>`@neB*YW-4enH`UQq_MFUDmn6a zNL(C<>1!1XU=?jiS9Mo>`6(Z_x4GGg?+{&(pP!!%aH07|$Kqkomz%4!;EwaVyRjaS z&2oJd+c6MV9>W1GNhge`lqU@}!^Xz8741$C3e-Ta;^YCW^7X5f7z;3=2pAdmsWq?X zjX~@;pC37Efc>FfYl;7jr5!jN7}bcPfEohc|HeTfud}C3qzr*5w7_}-nGlmCW7Wx*m$ZqcAAdznSL9^xW57sI*QHG7ulrrI*UKsD{4L*YOd8K<=|cd? 
zs!(P|{b_NXvv6*<{5im7Lz=Wqo7ay6lDa!-16)oj^-tbjwOB%0$p1s}t^Nv{P4?9K z*#h7Dd~k~d)|sgWGU*6kwDppD$^dL@>I&M4=FYeK_msJtlRgdS4$cN`6G;Wj-zaOV z#9*bq2$!k%2MOwbcvG~mDhcfRf7dBuVy3CtMt zxTy5h-Za^(~J?fb>rPzfcEwY%3$T2 ze7I)4+MxXUZha^;ng2~~?pB)?BvX0ZeAu4zUyofmpIyp?fV3zNYIAG+JxpUv!o>~S zQhP*-@{35U+OmN{2Cv%?%p^JsR>^=A4ajgTqCj~j&P>T9!zN_HV?t!|dHhw3DIwx4 z5YG{J^^tf~ZMIj4nmVZ19~VYPB@LJel+dU|jEUI;|2o$U@x4FQ&GERrzfsimIpqF` z37MGpPkT|$s0{66cfTSjjW-Y}$7!=V>|JkZX$eBZ9iv}g0|sjGuu&H_KY+nxyu1Ve z2BOROV;%FEfNn*Ap$OejeKIqra?${emSkLP$tR3y%T+f<78GZ#xzn z=9u$*JtIrugf^a{ik?;>oLLVk+1t+}BMdxKatLJ78YzkB@O8BB_=VyN@A#UOP)PE6d zZe={0xL5-sP6vcaF;O9>RIjlhW0+XKu5z7Bca(@Nf?WR~D!!d}RXc zBB?LB-Qs-|K+5he&I@1SMLP>1C1b=;DQdW&Ow}S8 z0+`kKn3#S*<^5ZZFyJ;fc(SsbV)aYBV)_2Ql7ag1`Qc2fleVw~Q|Zg({eL)4!zwU- z=*1lJiiCp#zXmW)Ay7RfW|9qo?;-w7VDWuRVb06U;KP6rOBbNqD>;Yo4$+~M0Da&) zc_`a^S%BHE#V29APB^DAWc-O{jo2kaL zi<{&pYbDyl9nvsUpWZwf3bC|pa zlYNBpZBNi9cOGIqgz1n9%&t}Kt_Y+TMF?cW*Oin9B|u0+g-cUbt{wvOSuX^4-&%#h4VmTa@UhU55l*Y&CBd;n0sSZs0MXKR(^9 zEw82@kHe!F6qO9}`U4Ge9A1-fcpS76jgmmK-bg19Gq`bj8V=mC`aVKS`lO~|7>dHP zJofhX0%>xC-Wz#7DJYU=J+6#WB|X6+DYai-AGNsOG25*2pVnpO$)dTuC_&msErm zdT`amx6{A>#NJPHl{m=tU|*vg9kG`EQNP*R6FDDQe&!xXH5k?SBa<((Y$0z7W4YS5 zR<7xVCY+wBebu6?;`q(dvRma(u7^J^Yx}kT0xHY*|4mMABsZRTwko{N5S0xyUQLOQ zbtM_%=Fqp|0Xqw)CVc^lhV=M!gqDIxH+qHnj||6JGW`!>k^h6?l(Kw={p*ZDsz2b4 z6YBl$S7cXnnRZhaFyJSZY#^reGTI51)bRlzGZ^pVs3?mvgqx5PlasT6{2Or9fZqe~ z2-3j$%Jnz_+O>*=Gn9u*KM;sm7))S_ppwHfA5;#IhW{4Tv9R#9NU>xux?7_}_|LcI zoXY5i{zMi-5F2Qe@9f+6z(zX7N%e!TdBpVn+y1+RCyernnc8!8=zYqETxC!3OE=Dk zJnwbq!bhzOhWKyq@Y&zRdeeSmViGBOc#>%Do0)C9HzPEkwF))ou=_#+tX7GbMLmNP zbM-^Qr>j}*iR`s+Ufe1L{qG6G-=LvHReelbH)G^u?(=Cxy!Ta6bD6HmmQ$zbN;fzS z6dAA5mP5wl5nHjsrIiCNEy;BL(5k|@8#WvwqRQkr9)v&E=^~g-f^?yBf9TG_e{Xuu zv}s*t@s;wHwU->W%Y zCVCRTUYCI7qWX3yQC2`^@J5|Oc${}hz9Jw{e6vfzexgkCOK7SVGgBVIpG>_Ob`r!+ z@F43mz)`A{Q;BXSYOs`dFt%hhK~dS&K-y$#@pMd|rxkSJ65xv_u{}Fb?(*AYFhEO< zc%QF#G~`@ukED9QjC%rCg^;Dgm~CAZQcP%CiB`k{)I#Z)E1^pi`L~(xf#>|4i!5&9 z57?jI(j}RrV@ewedEPr%n|~r8|LZkiAn883Ra`>P9zM8K{%-gv8IH5|b{JpR83eg! z5L)8dUM0W0->H_?s7PNUDcsU+arq@>DFCvH{M(`&#e0N_UYpaIPqS9aDREwE7CEb*W=)G#tDKZzWS}c|s;lv`VnMi`a&! 
z2*Jnyx{Obvou6-7Y04a8`qb;RoKXtizXvA&xV*&)A(&BSY7}Ghezi>+`RBD@`U1Sd zMDQ|M#&iVZ=)G@TML|{yqpc$0z3+`9!5pUP+2smd|AUvPGPJMk%51~zuhQQri%vIP zm!6IuMRIu`mj5smBsSg{5!C%1!cH{~z1jWwbjlT9xCoE&JY76*2lE6=T7 z6tBoG0mSVCUAx6T%*#iL{{R%8o0!QDdph0wk)DE5r*W_U4New-e%p=Gy46jweT^cf z88Y`{px4~zedqP-{gaas4%nQ>P2LxLGMOhs*t$d56U;(s9|-W1DdS3ipm32WIg%a@ zG(Oo|UN5I{f6tn?RqV?3iNomtB22sXUJG>nVe0`V&4yr8^$qe5B)D!^mRp={kUm%} zyzrA|#19Bjw8-aP$81b^V5+wsQz!x8ZXj2?`g0yK1g79vt+%g1L<#dQf=!H$Uk-7PH7bCjUKP(F#1aFlRq- z8iY{(TGKZOHM8hKX)41(pi-ITDoOB@W6qCs2C+1o3WSf1b|82Is{GE&Ev^{smo}N` z&k&G)X=T*HObkp~XUGBV9{h1rs42!pB7Y;pnA)>i?uGQvN`A!d`Y9Yw47K?7|F9E# z*Xec~7xxMv+wDH?K7NtcZWG5jNL+S;b~?QyMmyCK!|lSNP%?AuJ|ON#L6&Cb$JT-Z z8+G0M?<2#*y+852N%>iQTI{PPnxo)bA4+T$Cp^ZWM&@uatMAs!NC_fq z%tZ6D=aMo=fv)fP-Wk?L&%{;-?`BolgWrp+YTh8tRK}UIL(PM5yslG?ErBIZD1Oa- zC4htFX6m9D#c zh_CjI#WsLJnyem^l%zABGzCqRikMv#u9gXh<4JIlsR+^2X|kA# z-S3}V&%{=0$|`BM7=!D1EvMV;THA~8qB{Az{-B#!Tc=q(eAW(QGCP%j!~KqyOKQ>t zGbR;Wbda3$lxKpo3{GEu69QC|NqK0ZiVwfEAt?$@_`GkPl{PQ8EKaXuD&+qybt!`i z?MQPlwv?WC6x~@As;Tz<1vwyfLgv#Y@Ze|j>>E4OlN=hzLX{$zjD|Ab!` z+`M=J4`t^Fm^t0-5=-J!HCmNF-bfWFDNWYr2-$E-J&X*?G)xKa@B@NQCGo2gt5W0i zSF!1Lg0V*hM-mNMt1hdY9NK%O1f<>4MEIq< z%Xvwip>}}il*C1}?%UN&9yydT7GBUm?h6iqC^lkGl)f+rI7F;&kVOG`)eBts91IQ1 z@|^2Hm!LU~imb}U1kU2HoUvK^sdq7Tv?)^KaEtYchpWwsG>j6?wn^$2Un!NP4 zY!dybPc{CLq6lNd2JHbsZAAqXlm7*JKq0>Ag<#`}LVf&i{|RQR{1+>^_CNm0I(Ehn zy+>W2k67sc^c7)Xl@;36ReFmf-~BHx;){In{PZwiV}Z-c#x?;sx6?%`Uq-)3O11-) zW}s#aXXiWK;_e(vqa;2$Mi|U*i|&FNdp6Ff+c2+`AN39cA^%icAvxMBk<_b)BY$y3)Hq(tAntG*jb_R%0Y} zmAlx#C`rK0hKMNqyPm6IT=f=nEL_osmFWym%T~d_uSc?%%Q37|qv^s<$}q!x*V|Vr zpXJa5{f8qTD)yIgjfaO_FR!lIbABU=&c*~XUzEX5!Hwi>_WueO)mYczm;3w!;R{#k z$Bbs5z|jl%obA!;;6Fz!trgt1Pw+! z?%y0xp;kAfB2?#vf@i;!kfB97T{5?O)kzc9`3C1ZKYyeZ9qbZF0DF{9=IXS9 z3EUnzqr0{?CYqq$z7yYAXn1IF%ebrEO9=yM$+}<#_raG}t|eP4dt_4{o7eXnE0JTK z+unW1p~~s3X`o6asK084=VKzL;E`e6+g9?QQ<{s{sz=B0?EcI=xus-l^Wds%Y^b+k zr1{1OT)a>5Qu5t8lOWw5=P;@Uug~@pqH?TUqA8LozP?;w|dWz6qoLnjTmXlTEvXJ+Kx8-D$nlJHb2v1nOsv3o!H*rO>p z;}IEGI};=Nx3_5~YZC(%<{I!gt-scmK`T@X&JIP&*f=-ET_tWGZi!19NE#H=01%S9 zq$YO8KYpoAKrnmVHx}u#&kqT&wcbIhCWyP#SB%rMRbL>I}^Y2n5Oi3}1d5AoulUWIZrc=a*5h9+%}jIwvRR zOUzsAr7LN3&o;0lCzE<49d_~S2}NRFz{ZzOsjzmRckBQTp6`HEDxuB zgIhKwDogxf>5Ce+E`Lg82o{t+tTFmOz@KGaa*KIS3hdu&5vZF*`QeCKkLd>wx6lY{gH%$`Tjw1MBjph!V7IN;{ON-jB z9Fvx;_H9SBQ0%bz@Cli$I|_qwZPO4IWL^Pt8ZlOgSkYE^b7vaeS4jg3pnspC1PN= zIK5N7ktL2wcCoLwgr!9c4Of6MfI>I1-BA|8&0@h8o!*n3nqNJN=k3|`A`E=Rq_8|{ z#*15=`-G^;&G)pRdcArh#64TC*iXfkzql2AteMn8LRbC_!%3XdkAD!+bz$4Dh~^;Q0X-v)PP`La&$lq5gV zEq=@!ZmX__ot*oh=jZnk>GekFft`Ugml?u1v~Ql?K+av}a|%Es(i389cFkR3yuJBD zMK9$;7zZ50T?;dvNxK`L@V!SW^wG72BU&aSYLW7B|g#w@mhd>}mOXjxPV} z{L((|t>o_a_rVrzfjjHpC54E)kND`V{sl8uNCON$J`jFTbFu@>esW$ZDwNsNg)p2@^sV13DiUcE>+ zvwMX!w%0;$DDn|K7$vT)5#7ohJ3yGO$1U35w;~JyazfZfA_i`ZLwEcjnu*bl0~?DA zhZ?Heb%dXk)b7i2@r3FHt?5i06QnnEJ3celR7>wS0lS9>R0%+!VII8p1knDlKg zoc;Ox^E_ERZ#R?S<|0y)r&VAJ4&Mtfmn)$Zs+|F8nR##9KTR`1I!~7FF1~RauE*)X zO7$^OfQM10xoe}D;MCa*eBShSCy*?Cd2aQ*^=|8^wO8ss;~!nh@|NZ0G()?};z&>w`jRCX?*OmWTrOnv+0#v;hq-f5U9|WX}8=0*J zk(bR`d>GtPr;{(J*5Uj(KA@z*Cv@cazV^ObHRnW?I&PTl2DMdjK7ai9aetmgB8y6( zXIA=BR$2%a6BE-bj~y}A+S)oCNF&+;{NiQl7jv0kZbfL$=Ly4;`%PhHW;0&lM&NS| zthEUCThY&d3Ka2;rKK+%p2T$+4fd?(In!$Qy?ZGFk|||?5O?h!YV<0 z?x>ut(k&qW!KM~nBqPy2(vt-?6%~hg9iYC|O9+B=%&a=XeHsb?$P2?Q`jlMeaER$t z|D@tj6N5(O+rB||wV;FiNRPXBtVv%{&8lKWEA}~H%=*xLa==)UGlF0tZ}nCleziL@ z(+HnwNNJzkSX`Rd2D|hjXmHZJnzHnqpEFs?sY^ay3Y^~8M{e9Cls;r!3dl9aj#fK| ze>~YTv1%QgR!o?BhEbA@~fLWx!qMecQ38MC@ z6k%g$X{}ilkh~+Q59p?1;RkYI_Ewi&pgAOP9*Lp`Z$nG-SSOxb(Fr86Q~{iER{^_;cMP 
z7devVJ!PIe;-qY!K=9^@*688kfr63}?lm0px~(MPdb;;>2Qba=r9KXj=I@T)GU`@` zbN+G8C_k-MJ43{e7f}p*yBF$4qmH8%y?&?1vkpAUaranihAkNQ!N+VXk!qfPwu$IG zk;7$HIcpt0O_D;1>&(U$+PeP}0RnuLZM638RUJlsgM8-K89jcRs2{>=>Msf`+Sn__ zVUCT4ZTV9h9TonCV#qA!STGZ~7CMFr3Alb5=A-gq)-}lIHG-`?p|G zdlt{G$csj#T?8z2)2rOIg3Q^qRE@)RFZ+?bD zLQic9+Xy<6YkA2joP~E_7$yU1j6ARQ^rsW|?F6NzRa?-K0Nq@}d6y26+Z~se?Fn@5 zaI+PSF*a{exjGUGF!tzg4ilG-W)~AzdeZju3mcm9NjwFD8k>=Z@mCP@12|dw@T@H& zA!6-2a(C$0x{H}2FI8ot&||BeD)C#y8xMoS0TTKBRcpHsMhOTDs%K-F5epgl;di7UKOiZ#Q8Ze1~1JmeqGI(+i zI8ncVDNu!3+GP$2P&Q=OXWsQ}!BpPXKaMh>r0E0#oAHmk&Yq*sHqOQAT!vG_bC#TK z0~c(+blA^Id}}_^Rc}=5*{M7u-33X}AgSuiD3mFPc;T^WLRh7JMT^!INg1FJjYqH) zk!f1$M?VDD)H8ou0ve8GTAfMicB^`*cI!H-lACJlG`8&kX99<#Xe{ORn4AfzAn0v6 zi5Og_@oStGII<6VnCayMnqGlo>6w<$JHvC^ZQdgIm-!R+@~t>#&pw4VLmNGhbcRKV z-nD1mh~E&lN*iIzZBAGe%`)OqbXkEdSG)VMTUE#$X$=aBga9ymeph zQ^2zvA$1jd?WfBwT6AAGn;x0_G!;!LvW<#6=tBFQ+jDL%n z2D1`dr8kCyu9fU|*)=Q}I3Yh1uT}X51Ly;Sp(ZE>R%gyw%WPxl>Pt|M;xgTms5qTufqhrC}cXxcYHaCvr;sNC!H z{lNOChl8FlIRdVI&jCS$(uWvx~WA zhjR}o`{B-{3zF8CmP`{4KjtlY*Y^Z%3v4Fv_CV`w!R%v2S;6@6#I$N!p6Q8@m&|uCewZ9`IZ49Pu2m?ZnTsryB$Hp=_MfEzS5A# zh(9=TC|}5}chk5&M(}bz-EHTc^(;v+pA}YH;T^B{k8OLN5XbF~mdU}0`rl59>#z;P zDWfv$?d;L}%-`*QHhOJ;)(}O8D#iiKMnF(z`$Tc^Z5n#s>B?_wd$~DLRro6z&||Q9 z*zZ(P5O!(k=8HoIloKZ2T*Mhavazz(UkPTt;5Q@`UqqzF0*UA|#I!?0Ple{oH3Z!Yu! zPF_VK;ow7X+-b@_xfwTqJ~!WI+SvE18ZeFhM0nA&5;X@(Q5cv*bu1p(mA$aK!qb@) zTI7+VIK%&TbEqvW;IEvtqppw89xb;c-9scF#Du!ch7`D2sl-s}I4@35+90nfhoeZ6 zV2oCq2I{Wd(~^9)!`dnshBa6vKGjODUf@p35*gU@`y~t@dc$Q8p3>UrYHQNcwHO7o z>g6s`Hbz$S@r;%^;sanQ$y1aKK7jpp7Sohj1%d};!xaa`;dCiR5U<;Wn zA)vJH>2xdj?xd}=##AxuExOa8H6se1c4&FjZxqW$4vFo_HoU8uRF(xqd*z+lbW@|j zXwP@1a{yl=J=vvr&zkppGTVD$l>ct-^}?~5#tn-(k@mpvI+c~?$b|CCKK|YMQeFcH zZxGC4XV^^UJ-)tz*5w!y_PLz45G<^V%U&6zAC1Dj!rk4v_{?pGC+{6{b@aIE%LtdPdU+0yN1}GruYwl<4pZ@+ADOBQX zQ9PX?K|$?N6cV<7ai=*iL?CW>6{kCae%U$G^#O2`8qI@%wr6}-~pH-j#d_J}~McN`%osP-^}a?W4M0hrVq+1i63 zh$9p~~8# za%)mMG|(O3su~%SJr^fgYGs48I17hryMGtWH`3jC;%@%>7^i&S5(xG=F1KJs;79v> zrq<~7=L~TvjUK8-oX(36+yBQnrZalq^y9(2yf=fk5=u^_I)NqVt?i7l$BAOJVFU$>(g< zkd3qIyD{h0aB-VGg){pLwNh~9&wVp8fO|FRWlI$qRwJ-98w+pNI_Vziujz#u*;7lJ zEu8iw^|Z?8LR?*52Gw`B+PU>Tq+X`YwX{gu7sA41Dt-IRWcfVAJmZZ>S+)aqZ$P$R z2LA-^JJ@WVuTN_T$K#1lT=Ab*JtxY^eU9r@?*8A@YI2}>d4iayscFUqSwe5y2>y)C z;%{no!{dGzF#3p)6m{$#Gh!SPpi^*O2wlw^X`JZE&4o_oehZ9zD4=^--2PB$H4_~Z z12%y2M3(&cYpvB&^Ydo(e$|WfV_NA{t*EHK*?#OEx0nYttS1N;YGnyF_%Tf?)fl&w zFnKWiVPlvbI8jpvN8ahMYahL}kXlzd75U4|q0>}g>}Hdm^rEn@)>=}djSx}T%$X4( zTX!i~MwUu3b4H>0)~KoyKG|@cY-KDS?X@H9&>tSMdoY;=#gv|ZrsXL4R!3NK&{&7F z3_)|%qwvXh(lJ=&P$FTk`L5(ZPF+~nklOYvw3ic7(|HCf z^^g;2NB&vpXO|GnJWUofPBS$bN_(g|NqvON5n}@^!E7L1@qhb2Dp%A&-$#VJV++M6 zy$tP@TmzAI_^mdLmnkm|`bc2e$Ld`~rCLI{m=t)1lJ+!RV#+SNR@%K2XcD7M< zfF7*kHICc0TeTZJF`J68bEQeElaZNIZ1EjkTo!4)Me?c1XswYNnyU-=t~!_3o}9u9 z{-zo9{ezkJSI&fi;Jl*sza=0MW&jl_3paKb(6K$kVqr?X;3hZW7w<=)p5Z&v%a5>s zi9bIK$faIL5Za3(R1RPF!TNEogf0y=2)BvV&?lc@= z^5kS3Q1I3y!pe$R1Zrl;BwQp7{JAGhHE6+R!QT{fgH!A7c6Hw8Qw{ALHq83vXH!kG zPT5|sD)5K81~ryBPW<_-)FAy`>@zT$h-~qszV$V+g~QM68p9W%hk>Ca!jNREe}$rd zaP-rVIfKGVWuphs^F_l@f|gA}ey&5xTFo{w+47+|wdEI5*7E5+n>*n!1Z-ZpG_~Jf&=G zqA6ak73>CDmHGLh*HCXoydD=b7&GE@;D%Z;z9k}X+Qt`nme`b4=5|_Z9Q3lN8n1L? 
zm@bu9U?dz}oOe2b-kegL=xu>ZRxRtjs?|XeH*ibVDwF!)JndlD@w6{_-00|E6Bj01 zQmQdpW|x#Ddi|0`_eBPHKsFuEuVtcDU&4XS(o!R-1hgQa;iTVBIY@GRB<9*p%;L-h z5U9ZL3FvWneyENbLuYf{)9;Vl2!3*wH2c-GEE~F+z{#!%1X4}tpF0(Y!>Jq%y26(e zi)n~AdYlYhYMrEl9UT0w{6*X9kmEddW+^()E<|1B3q&LWx{^D3aRmTTZO{C%= zJOWi;K-Hur9UKSx>CA2dh{*N6UugRw%8L%K1%4`MF4&@xEd;QYU8D#6v>b zxgN47Yzm_K_iKq`S5uPIlHcbYLTB0WU##XeslBimr0L~LJvgM7`JtvI{63?ccSW5s z=Z+|cS_sM>wyKZ5w9D|^t}F+R-42_EP;V6M%9o1bK_=VHwPx{J)b2#H2gAT@7cnzG+b-NjSSx7kbvHLhpl$vR{lbws(ILHt?W0sa&+2r z4AnoQcfO(DVA(x(3DKQ%zA+4K-V`Axzd+bMCF&NuWWy7gN*>YeLE45fJXRxqfSDK= zlb;w>GI>0>e`;IGj`l=}t#jzeyIM2YY2RvY>4B$g&-M_QPv$%K8V+Mns7$58IezC; z@l&EeB@|ldpeA3o{*F_SRCv2hzF-%Zr&Cs=(rBlXzOCYKC_+^` zIQOOVf&HFDZyJ8RP_ICawacN}nPyu7eyq0Ne2m5v?*-xRvM9#etVMoyi)Id7j!iU( z6Z0ENKT9+{_XbO)XDIX(y$Or(JF9|C&2+A=>4Fd=w}+V_Z5Sw_`aO9Elv@~cYka+O zDwUSeAI8BTpw|piJ)OCTo%vmP!~`P)lGLG(w8;Pf@yj>+LwdP-i}FS$$d9XO7JIAw zuLd3t8jBn|L9ETn5f$UnJf`_OTh3OZM;Z+l6CK3N%Hx)LYn(RAeI1e={21@;UIdT2 zEvnn+2=SRqY~x;`l$rF(4Yr@X0hu~G8latZ61*vEya4dcR_O9_TdUDQjech<88!Dp z{9dTH(I#LYKB(P`mi@aRAfa>a7 zRgGK=SSZuNf$WMo|_xYZ0hT|>{0o+&)U3lZM9mfRf$Iq)W%3l1Zv=P zZ!|(#2qpG3#Tem&>|y|VMI(T))CZqYu)k>1M1aGh9w|*P} zr^BBYRBr*xW_sAD_MG_Vr7EMP0s~?8+{0#P3@H!}_p>Ivn$SxS^Cc0^l{b=Y2lo?+ zJpu&-A~h2-@_Yqi%~t_7L(FzTfJ?>%1A$GcViGgHbees*i# zIb_m<+OR@=$Dyt)#;5UBHVJpKC^J~CL{{9io!^Onm@!pw#{Lxwr`Xb&c7s5t54znK ze4K??Ie>56b9Ea!vxc+*$^cv~l81k47ww9iV$d$DaO4+(?v`J2;A*Rz8ZB-0-9&G1tv1=-x(CPiugdF zQfiv7V~m&R2Y9vH(-Gk|8wrZ@dxTPaZ;{UAolO)f4^Ld2S(+yb4UKJ}ZED8AcYPJTBs-t;L(1JC zLn_HvCfphgQ^J<)1?>xRgYO3+X6I!6UPoxs1qsK-WX^;Y**0|u3b~tXgKaF#=P{Ag9T*VjJ=yD9tZ4@#9e-}7R<-%vx9tql*7lWKRG#T_xWu&1jz)3Lf&ux`t!&(zCx z2U=@5&4FPb;7nTK&mHv}sF zls2GxkT0F&%hU>jzP;t)6jRpGPAj^lZfVu@oYuP?=&o0qcl5+HoG(c%wYBh)fgILP zf3)p)jogN8`^C)L8*)h=O8HneNU)eaJ!;!e*gK-t2BG&)<$4PdYx*2OWViY_@;(Yb zzj{C8O>1bH2`UJd35zSjT58Ct#34jyAOJGopyr^t*_KO9xEN0V?2EmGzA0Hugd}lZ ze_aVFMBw^FRW2(ApCyK*GhEw>Sp9L;-taXOv4agCfT@6%qFOtvDis|^4^DA*@K3Wn z>fMCYLds9KBP~ZV-EtB_CX`qww&4uS)PW@8)JGm(^5n5gLmzQ1I=xw#B2tPwAvt>q%sCwmPQ%}gXL6f2RsL(nvwx|^hb zuLSN)LHpT(HqT{|m;MBh;HC+xj`|8!yXTAjgt{Tv`OGxv-)g{sNU7p@gMD4=WN#T_)I_ zOcOYRJmDd0@uG5xoSl1Ukm5}@GsVNu_ethm$~lb)MDk8_TpSDrS#?QgAoJJBi|9-qtUf`U*C3ATE4B^gNr->V#E$c(&Xyr}~jZ{9#! zLPb?oCm?NHPKavjw063_y?;Q~Q+#e~K$*O^_cqY;``9o&IS>zd;u37%U6)ZOxPEPV z7&`+tbxigDfxx$6eHMO^(z^+IB69B`kktvLG3!P%0EH!4VHWy+@loZXNyJZ*lIn4L zF^3RELwaA#p#pc1aVdKnQ-=Pd2}vFkm+*69iX_dx0sD>{k5d78ax4>PNX;uA z@2msTyBgVa_uta?vRnJ>ju9JPaS+nUzUu{1S|zeGNeHMPhAQ898zaG-Jvc@J7wT8` zuj)(g(CJTSN%G&Vb(OR%Z9jfkfnT3|xJopXF}>ERLxw-qNw_*>rutztCMV%KL=+a5 zf}y{B>*>|ET6OzFt1##VLy2;m7a#qFp=L(lYhLFWQ%XrY!<}uN&3TW)Wc)vT{Z&95 zUDt(+26sr%;1VRbyCpz?;1b;3-QC^YCAdp)x8NS!U4k}lO>-*W+57H)_hm!TwWwMp zYm7O^BVSl!S$sJ*Nx&|fiCRKM7J`@8GErZHWVhFYe_C&|%6Il?4x>HBYxMU*uc#pX zj|mYZy%Rbmeg4Mp^vI#fjCL-FQGC8uDU0-g!Ft)P@x-@?Z#5eYs*;PRk=963{W~<| zJQnBUk}c1}Q9hc`0b6Jw;0Ttu_#q_m#|{D#m(y&{LNaqLq~n9jhv=J)wTBb9Hzak1J~pF=+o zW@OsK%sbZyCw=aVZ#qLWE}X0m%1*volzlyrvntgVNs@Kp zzpbHbh6rIqsi!mO?*uL54#SdMzLjX6g@c0M^X6rPXSH4Hew0y+<9`c3Mp>lzGSdzu zg~rHuM|yOG_SS03qk?h@IN@zjPOwxZthf*J3VCm2SxJaj3|Ib!c{PRnRB7Ako_wE? 
z91b_}v>7rZ@o{dXw%4Ne7oo_%4VGT>UCD6hB zZ)t5IR7SvTGuN?bH#~o{SC`7W+(GP#AxL0eJ83p-{UEy6e^4_b%HM@WT#a)hKUT-q zAci#^I>D$<6E{0eA6&`#7)9!rXB12e&(fkhh=}$03RpdV1Ma@}lj8UN0W0&qkH$Sogh~ z?iIaqhV4}LTY{Z`A|hZdf!^y-q9H`iDJlU?ONn`5E!?q(M#5d^&du@I1OGFY;KN|| zW{6iI^ik7MPH|HeK8G%{Fp14AWs5l$+uq~zi5=`t8S7eNQzwm~yzBfSCa)MfPKRZ2 z{a{-N$ngFVdt_Xg*CuJ#S=QifQkLiC@!hoFM<n#GaUqP?7@e(P-jWz zG&Tc4#N=rYF8!T_vRD#SQMfun3WTryz*V?8MT0nwQFAa{5c^n`(+L*17bxG%at4S- z0W#@5ID~!8H~fQpi=lR6D2K^$i)UNmU6V2|b|7yDnbtL*op%pl`44MTXX1}dejXWF zRo((O$8zj%?Sb8{&MC9FcF5}GTkVgFq2JZAY5z5b)z-R`a;|jw70-R9WPf|Ve_UVG zhh3N3MniC`Tuu@W8yufeUJoQ$eGs#mGXC?+9s12h`@05xdgQBgrax8$2b zAP~?jU#nOMKasH#*UgcuIF5Jzk?fp+<+t1C=K7|UL}g*xT$ym@LEZlJ1!e~`Ebsfb zw(rMlMa;DcL<%xyR*$27_{Dd2eD8xY(RMAIWl8V8oYvF8%J%QyVLhSQDXb70(}Idb zJ@lz$&hf(PR&w3Z;{;yj9K#Cji2Mrd$6-#d_=i`{Ze$@y&*GfpXR4!U5+xzRF6uYS z1kV@{k31wm6yOI?%Ra0PH&rx|Npz(dFbxvauNwrhz2wahP1lo$5$E(j#f%SS(F=uC z9Q&>E*Y_{imVV9MY6qR+%`Z-Mu6-4Xcx`_nYF?#vuRTdmt>?#_!5U{}*xnqmlabLT zN|dPgIid1yv8G167)?6fEDhsqmdz7iE!2_XnKs9eRV~z~B|_LMt4_h}mMY5V(?As{ zy?8b~{f|Mxf!lRoH*-Ehj4SU-naeTA$ECwPw!I?cM{^$zrr{Rk!`0F?0au>F%j+g$9}94EJ@M9*B${=CCR! zfpJxT;*%A$rj&v+Z(jh=m=nT_SR~&xy>-YqgX@-XIk{xsz=0<I5x%Uozuqj+=6&LMX^zp9e znOsKY&?ED`+Wp31TeYDfSKpLw5C03k&QRB{DWt$6k<9OWm6Ml$QcNoYFLlXHR*SjF zq}0LcT{%-ZMSB3pfONzCqUEVq)FRD6=_d>6@~OAl3`l{`tecH>)k*7{@YJ0_;r+rK zoazEG?IgR5D$9OGvMUFRD*$fZRk1~jN~|o=@RV-4(p+!)QD-hhtGkxqo{z7~Z-ua^ zx})ji<1!=T?y6g&X!ab{thU}1T>QUn>N(V#7L*rT3w)R^G%aiY_cKNPVz=1CN5!KE z)0b)nYw0epub26jo$X->(g&6n%`0OC_P!!poUh5`KmP^Lsa+oj!}@78qxuW9Rw5|D z|8B_eqB)$+OGT3WX|AONPU6G0k>6_S)pW=E?E|M~o^p%7IH*P5@H{;sKn+4QXlWV4 z|I1xo?6LEa?n0BP>>${*tiiG2Nv6k>?Uj_8`OO$g-$wCqE$do=U=p;tvh0xePN}Wn zHGcPRStaqE!0J@gqQ-htGaryV>chgNTHcjN5;dU8p`-T-U3tRGdB`r)@Vc0)I#J$P zEq4+m;B~{{=_YU!*dv$PNGaLh-o9N+l@*Q&(1);l{ki(|pf)gd-UFW+S@LB;ONp{# zNbs;Q$ogNhwrl-uBbX3@PnsVU@W7llBTdG*ivI;1dXJjU1cPh5%M5qtL@W27P$9C9 zEVA&DQ!Vj}cfyrcVA-3+2wfCl7ysAS=DWj79^#ci25%@G{|)2++NReS;2iagZt~4K z{1@j4d)og!IO<4p|KET4e}2m-8yT!$QkP4?-v70We*~t_xlmEN4e7Q3hRe80s|6j% z&JazZQ{nJm@ASX|V&cRBt`MK|u|F=OCg}F;%=&D(9){a?-M^#TCL-@T{fW1&$3ubq+O5YH4XHk&kb#6zo^% zXBC`*UVFBVnx-PKyG+3p%=~DvO;N2X&>L{?4yK6fQEYH8VR*fdf= zXV5PfdrjnU#}8D7vN4_Ezv*AD>c~i^?$&;peZ89)%V9c*z^pG>=MVc`iO}2IU#|N; zTR57aJ6k!=hi0@6?z6P5=a?M`DZ{wHTyOPn;707 z{F3_p%&@W4%HQH}`3DZW?n<*O0Hb2@Ch<~MLoQM7Y_}&YN7LExQepBc(-DGr{U>pI zy@Gy%fx&~ljTT8#s0|Oj`;7QRTf>%5m>#AauRlKNJse@2r@c7=W$Enq!`D;E9@o$= zes+a&U8f%vTiK<`kYuN@gy~&+Pp&{8(FL!kzANI-BcrLOWygPE)(CVz*ATb)r(!7& zNWFKys)oxIPcAo@?DT`NNRtO`nYrpRE;no!21sm$j2d)#;aI52IaZ}~dMlR8VfldE zGY^@!3>&iKLV<)u10XapJUTkKrKKfw8{h)<|lH}#agy@ZWBNAcDcJLgTD=96;{zo z+D*i$w*gNmfuX&@5k-$wb++!#uj+bH_G3`XOvmv^%Ar84s%U>^(1Mn1>|1hpO7Q%g zG<;vVoK?+n_;L%Pti&5KUdg-9(h8KST&YtKFokv}KLrgZ{D&9Z!j*CQ9qktD{@(+E z0!3%OL31XmeThhKK_(}Td(5By=3RkT#~QF0h^($|I)7J6uJVQEX)T?P!M;6@#B_uQ zy9ny5?C~=FE+N%IJc3vm=yVWYC)$$lJjB-7+d(^ZSW5EUU};O7g|+#(6rE*E?Pe8- zd~?&)3yDlI*jYr3@IxCKm=IHEJNWuseyKR>CUCaaJ$;jWFCC%&eKf(r07nU9s&8N& zwPv3sbh}opN`E@7#;mCEA$g1zs}ctjLez_`t)mc65_7w!*XOaFS6B0-t&v|qCQKT( zShMYt&Q4w+gas*%|L8*M<)<~qZ1v0mg>*?dFHKG*g!}h9#x~B3I4<}*Z=Kf)hJbW6 z^`yRuG7hZ_Qk_HJu>Dn8gp}OWdfypc-33!?b+eRL zsOe*8%o6Z;8(V!Wk!c&r^rpJ6ZTk3k7 z2X8V0EAjn8r_a_yNs_jL!_wR|NDKsw_c*Vo?c9k?&xkMP3xc%uR5_G6s-Cm;);wMY z#)|1kh}L4$tA*JE1fq~0Tul1j40M?&e6u6K(b&t2_h0+_&_+7fNL4&z4-c>gTF9R- zsp^bGx^EoJr2j3w718l)t##R%f+MWAw+x z&{2KG_~x97UH=HnU!X)RYO4XaN$)AzXelINH&3(AZ>K**z1y3&-T5!{AFyL7(!Iz{ugRn0I zBcAs1Ot>E2oSg{M!+w>eBM#1q_9o1K3YkCIKl?JRuMyiuhj(rDH|$wScKU`NxFc>k zSS_O#^E$F885m5O5|jQBK-^w=LR*3MuNfWtqNanxP)#eH5tvwq`CcbG6=}b)U#ee9 
z6}aS3zI|)qn;83qvZL4hb@#x>1Q$$|n;gt3%uftucNW@Y#CdN-fgvGe7qYK4JkIba ztxnGpAh6djh1Ms$w_fXg&ntXTlN>{;d_NtiGO546UhGhYOj*?)!}AX1$kQ8M_`dXT zq`N5!1;YE`Pu5X%2!?Rx7Ik`J{-8eG=5)J z>*4CG1al@M%ShpPZsrH}%)fX3E=RW#dw%U7<(_^2n5ASaD1>*!S*oR;~G6 zz6ftd#z&4lQ6ANrz|L%}eE_{2J;Jf<*Wv5~s8`lBT(CA=8dADI6Z z1&5|)J5L#T_TTVm4SfCHWK9{JnkE8!ZuQl9+S$n2e~5zo-Z!JAPN~Gdkcp3lB1nkt zu-?k9YdvKUNzjz1#eSkcq3+z&fydm&LSGbJuUW9aWQ%?y34b!fSB8!bup0Ri<7lxY zcVjgjp?3}3xy1?SK=oB)hnqIRzyW484#=;}K~4?9^{e=^+9M0lgLUz?RucqCtmxEl z*yj{4CrS0sb&28HR$8A&CAnVw*2((xL~Q8PQGT&2zMbc#JC4WDZ`OE&DgM~yGOL38 zESDZh0i0tFf2E9dan|=?qv2R`Mn<%Ou>bMt-s9oX&?sm4`@dV^^B+@wetmg35LlI~ zcoL3F4o{Md>zYdXE*Zt!fwZ96RhF>fH@0POGb8sYmPzR%*MR*Yrfo)kX=sAm041*2vlQ8F4DdTaav=$_Bh!%P(jwQ;5~D18s?18{eIYty{yrmxQtZ}on^dC$Vak~Xh0v3!tI zCy@QCU;wEdd8c9L8_s)ydTc%5+S+g4G4H$J9}d=b#?!suicQU=H1Llk*k4;)p&?+WP!$R*x6fMqwyC;t#bL4;FK@7 zepa~3-q?10ngqR4{lLo%py?#IpoA=kYmFF0zEd*Br{`f+hD@xzk@Fui)h15-O_I9| zaXKX-ke$!DN6Twk3OB~faAO)M@$_N3Rn{ziw&o>jKqf2xI&>a&_L>g!LOo6Xx(H`H zj3ZZ-km&zs_W{y}eci@VYhL=WQCXW);Pb}?m_y}nfSM9gZQfa~&)H(qLz*!o|MgC* z?Hx`PJ#I+h!|IPZ)w<>b5#fNJM;WQz>K#inQ8Adn4OI{CdaY&Ud98{8h_el=eXUik zCNQ1M9BddK8tR{vHQ|!~P*YQJXB`n|isQWS_jm=fxh8l0Ec<$fz!!y@sFmxitTJ5O z+g1~rfxd0T0avEA^b0bqn(4*a2cAH!cSJ80SOppzCp-7KZ#&>f2XN9vGRE_X60_Q> z6O{oYBfRK1Eshm-0~7h6!a>Q928x{bXSCDmjkkf>r#_wm83=v*QkYiKMd)U;;f;m| zZ2x*W6Wdhnt_M5wYagte#Pa+(MH-dzZL%n7U(IAknSQ00j@8gpinc1+#4;(@MgwvJW4B;(1*uQtqY8dB$B+{~O(xzdX7z1loney#-?reSwD7i&%a z9Y-U|Z8RJg^+Jn2;zoWMpSy#QBVK+F-Buw@vqD7S&pd zBTs8+-r}P606m4N!Z2s3&he%VDisiU0Hkz@N=1y#q*Sd4ZFkXuc?c_axYVm^46G%^-rpmTwWWe-o`@RT6uethr2!5ONfqp zg%{CP(nWDLT)vx(x&A<3{I$bU)OcM7tjeOY)_T62-z%pJhrx&Xv#`Yds{8Y3!Q)PT z;_4U2+Zo;4#7IgIKHzU{Ei)TU%bWhm%nSd>cMttF-KI;M=X`5WM*7S@a#DQM;m^`s0I;K@IFtCmQF0NY&J6vUry#OPy&jNWsGa-hi^dicDQ~<$av?bx!rC`uc?6g@_ zQxLj!l+|0P!_Bj_1UgHYZ{OML^W$_{bh7a}j|*yy2;Z1Ack#b;<2xs_<;5G;`|@tCuF248Be(xZ_@G$qi@&0enT zD7Xh`5CvlCC{fN9EjLoh<1GH7zR){jMnQD`_HFZx_op{=oc0q4Bo%vCaIp3`9;6KImts-k>^(ZP%u$7 zn-Y>ny`EqWQ9(<8ldguf7k=fS=}q+cG(p>yzKi0B0<8`i2!|gEp#%t4sj@ZivetxGou5g5MKECjo2SU@rAVdoUmc!9x`xN zu4g*yPv@e*DN20`{b-BWIDDOstSa41C|=wyn?4-{Wq*6x>(=3pDw7c2Pg_HQRN|X4 z&Jp@2+6SHeuFxuTwv-d>+-*3B98DrfApCm!kAnb*fq~(Cx`;;V4T1t7YNIWWpAZ6w z`CXG!V4r!&r=M*%bJ8>3^khC>{$BcQ_Iqpbb7Iz&za6_sg2zOVpDiB{;3osl=@PL& z>czAjk{5b+LG=6qy(!~Jx#WGTmpguyd#?VB9Mt~7`r@Q=q3#`uk4yOx{dk*F9VR@m zhJefS)F{@Ul#4a&;OvB@4;w4nu*Lk9UJ%}#-Qd8M>tBXl6|sczDa>Z}_+&1FS?7DZ zOSIsP%er$(#>%Byvm3l=|2N8f3*4?cmfl$M>8W=2477Bhm^v88LA@xPk~~UI!FmjNj7)_ zu0AwE3Al`_5kR(aI(k;(j&Y=ndu}4&J}P-^%B(x@SSC2wZxxX+m`BXG)#Lj`ik8ml zU)=q(#tp>|)|I*_-6;MwY*aC7$f_r9s-6Tl$9a?F@p}UdTwe{1O4+)YXqjxD zQUdQZ+G~qI1&jn%4Gzf}sga)#m#6E0UHMAP_)ynA-3wK(sQP&6Z+(uh6`e9ges>q2 zmfOE{p{(s|N^t#W^!L_>luj)&-!0PhA6qpB(v)X!Sk2Z}LYVXPi`S-Yq!d>mQyz9G zp{o(xL(<^{uM2se)`Yw@PV<1XV?p*4pXrwCVqrssbqLKf{S&$?{#o$GkmQL5U3w3eQ$L&->BEFM&bLM zrfvkstQv{OTaLWa^TuUc2{B%Bmuw8I;;giUf20^KsMs#X@uGPL5FGqfS_23A%=#X+ zpCn!wev9Ih?mDh*Hdo@^`!_EJ@Q_~Sa$LAK3Kg#EXtRP`RU%%3(j%7 zG4-bOY#yYWuMItSnWmxViPJ8m8*udtq?$f8*wy7=E~5%y#O;o1>#eo>I`&G8yP3o1 zH8U1sXh12DO~0y8>xAHI73Tf&qi&#<+`HnkzHcy`OGpmVleVd3N@o{t-trj4O^Tmr zel|J{*54Upfb(QdN*uSJx0XjzN0M2$gNi0xcXUS3vBtSFTK7)^OSwyoJuf7$7#%M6 z*lysZ&*~{t?DfnxftFOJ&GJ)E=15;2^{iSsp%k@hzvWwg_Xqo3eb|bli|TfWu?a|> z5H{xg5|I581J-VeSt+!p5|%utMWhcPLqYlu+UmK+Q&-p6EoR0`Bg*qcf;03fxuD;B z+p6~4naMl=*R(EF#c;udRc!tIlOrv9F#Ae%1DTK7(f(43j51t#^6K?Eo|d8FFV*Jj zH|cHLe|a;i=S}LE*%iGCVtn{vtJB0~8Jj)$BG!>etd(Shsy-FOL_c-aC3yc-K=o_jdg)?=#1bHVL8mf~*XY@X_LHvkf(X^%jCN6CC%>1Xx<*G-~ z$X53&3Ic9b-T}pFP-E_#rZ3xN8qs6<9%Yn&baeEv>2t9SRIN@2wHb~43l2lod`-iO z%Z@cn^@dtcD!U~wt-dfAWm(l`zfVU{fMMi_eGR)-S8cX=Vf?s4T2TJ`G~jt*Yt35u 
z+?}52S=60al{Q*HTRjO4+}hE`X^f=T*AKSTG7QGt0%lrqbLh9x;?X&w$tmd_a>6`Q z%e|)KrW$Xh&gM^o0~`&qLeLy1On;g^HIyC`LcYRTP1e`umKGgJ4QeV8h3@02qYPVl z2n+zPjiyw;Tsjxa^i|sNcQ(IPXbk^z(6r=o?=7|*X!ny^Y`MHX!QEM%z}U!_-8!=^H+hebbv}X-dk;o${5}O!ibmW{*Si{usQm zNr!=OT_WF5sq|TY3FlT@`l*rFf_@A7uy`vQ$>tl(ps;a=PKMgNx89!)?mP|5ANuG} z3u}z4%y-YXO@+C&#~!nMvQSovsG+(*kYb*N57CKNbMzhRoch$Xi@Hu{?83dM&FG-;r?ldnAybiSZuP*TK$*P^>OOmD7b%q z%f4Go?!UH4CxQ_Hakptv^D+G2y92Ey8Y>h<{kBQPS^zdpMJhlN@Bmn-~)*7d&nHj~glnZ;cXkkE@tN(w~7 z(18HA5OH4z^8Xs0vcYya4w&0!HJDbpgp7?169|U;we1IFSAm~TYp-A1+%8FM*INVn zf)EB)R|)=Q8vrQze_4mE9UXf>_#q$M+xoh;Ob#Co2?Fy(*iRx@KYHy zpy}!9i2$8zJdIgGQD69r_|{xR4GjbIwVi?6 z+xkX*`@s9}DfgCl5TJ|w1i*qHg?$B9TkPz9RC#!`J$fJVw5pZIM%ek3C_V(}WYrzz zxwu>G%FN%`j?JN9kwNI3>A~^pIW(AQr5c$WKk}H+ z5AE^Gny3|qSkm8L+KQ7ybZ53!gT|3^eIQ=;)+uez65y~x7 z9Iu%~08Qt4v0UY(ZEqBj_zf+B8$TRDUNLU`G8m5pJ+~mD$nj$2G@dTPzlJ^Zj*&us2M`;8-LALOJPW0_Tt_BH2PY!(uDIOE+4VQ&Nqhf{ zVXmGml`Z%Fdt=&DCmemERDxRCOrX6W`x^>et_gO4d;aU+|ba# z>Tzw%pkC)sk5(km9;>0D0oZrUZ>6HDZ~(hTz{=NG=wEc=@4>+UY?LDRj~wRjNrb$^ zkB`mU_%P2J2QOC|5XMs(2U@m#Q7nV7v9QbhUhdCfkR*YEPNKBpVk!y>3Kok$loL7p zc>faQ#sV@qtqOEHTphJc1Qn50b+S`%@;p6?^#SA>o!b6HIx9>Tx833NJ`lZSU}RLd zxxto04@a-D8$UQnljq%^VhZ$l?n$vJdE?U?o0;q`@8p#Ao}A72uyhhuGI*GGLY_CY zDXD+wd(9JTfN01s{G7lv6>8}*}GJZCBy6~Gq=s8aIANE z@Tf9+$|0r~^6bgn10TFdexK~`#DF)r8`{(6eeY&KZf#TT?lbhLgsl4ffq3S|Nl5qd z8x=mxs*%h8zh>^OgJ5v8)>NI(4p!2F%eh}(vi(~4er>_M(x-u~!NL83RoCnRig;#u zI%~_zv&Z-0sa7&vR&0LbqHJjfYwx+MFRy;N4EE`%)r}wXviq9W)SC;=Rz*7lbBFy$ ztp<+JBq2aySDJb_@QRwPsAfh#=ZkC-w);)9)is`cy%OH)5O*{>HP)BVMZ@BDg}ThY zTE5wliaVp?mxa!F(2gdk$frs$yvoCco#!c3Y}DmI1^w3oKmB@heEI3Yx8@c}u!~@k zYw?5yl~0XtQETbq`p9GGNVz_Xc1|e$naVX+Z`ed3=$+8mkH=<)@3dGW((K!VaG8@^ ztqz~KPfI&F)+E^v{~uSPHRavpGajDj-APE3)zVm?Ff9xbG4amf;ZRiK7k00EYu^EC zIfnV*lgPNp$lbuMygbspNoKlU1bF!Vg$0!hRde&tnk}|50A&T#N+RfKFY|4D6mHC3 zX;aUkDqZsI+m81tDN4RBLkP_%PXsco^Btmwh-EC1+<(qVJX* zHIn%%vP)g0EF5gBNsa;;&S3M0b)re@ulIepS2xnJUP-1wpAc%uH4ZaiN`kifWL zso0*TH`P0cS%`kU%z8X5wZ51~bux=iZNC$XJ&0-Y9XQO#fJ30cJ6r!nK>Iynz5n;J z|2Sb~O+AQm=yb-HvEKRMs|IMRKN*TLUCqI1yK%8*2l-h6@$CUVQJJ09)ck#3dCM<0 ziQ1`5_N<2jlHSo-TpRG|uLt1Kw}!h_eVv!3ma>p8Ru4u3^A#jRu?noNnj%J6y1H$7&REPi*n+`x#= zwv!8cW64Q(<9#HrD?-o!e=Jyp;%Xy=$R8$U{-=Qae}}Qfd=7Q0CKi5y*^F){LaF{_ zw>22V@rxeEYM|uqD=D%Z6Z$$IB`s@+R<8NG&wU|E^<3uKc);jVGD0|o?F_M$feJ@y ze{`1I_FAz?x=-wrWl8wHzLqc|X17Nfjzb-voZb@vk zSA+%BSS?h9EHA6`ae98Xt7_+A=TlHM@Gw#tcHSTP@S|EY#czHu#-#?fs;H{4nBI8j z3oe6k$PJ~)lF)#?`GaW*mE6vP-5QAA6ZK>9?EIhF&EY7ctzGDBPw0%}TD^M2w-;lA z&`yq+NwMX87fB3b)1sX_Rvic9t;OM?(+sRIkliCmwlMKfWn@fy*h$mL=z?nHMFhgc z>h||w;FRt)DK_}iK%Lj~0n^DdOpa9eUCS5=xRAuI%85#UfeSY}M7}Gxp1FCjee#Wh zmnuoZ?_o_yhNyTWGsg59kSzw=Fh_AZWnKbMT1SN^Z(gC^kri?05 zzO=EiiVsGLm9{Rt6`tpoGUbpj-6cs)GBe~>&;W=s*x4#3nxka@^p)>Egv7&|D38Y* ze&+=;__y%LgE(})b~YUfPin9t$_R$+??l0msuN-%-(wCsHK#q&yT`;_dVG3DHI1uR zPw(r+H_)GC@`+zR_fyTLtG4B1-^?ckGq9xtWD)^h^%v7&V_iU3mh!$D#LMq|hNi`k z)BpiBI!?p>X4AO`kuW|cx5WvcL;%z|`~Ci>Z2++FXGBgF8JKfB=ZesbPzo5hxVUkz zSU=>M$2~lZ7^SzQHC1WVM_2X>zD2cSMuPbSBJyFYXlZH5qTn^hzDT`uta+joc@3T1 zR;*;on%(N1ysibiArrTwrDjp=2=elN5PY_xgZT_Yi=w2S*%ZSlN4z$@bld^N(e0*q zt$1FWZ)-d2*I&6G3&p%kAU?eli2Zi@!5OG(&h(AYsYqprlo$cjcr*e5o|5Mg}a6xz|I_ zWaPM{a-;#iqI!K!#GZt!=Sl~_1EqaPq`Vhf*?hgO`O#s4T)PWD3J(FccXR5n9&-!a ziEti7kgo`Xky|3aU-Bk`t*;3NRL1|~bkoXx^G;mV}o?*nRwK&J@4>{c+kXIP}sJP(P zVZ{1|5VP~Z77!D8G5NB-L&B>=&lI2GDsTfs7V^6l!QQt~z<(E_-i+^Co%J(Tjgc0G z{oM8UdopRGJW`1|6q2+!KcjGCkAfS>_8nJKtOlnXnm3(g4 zJNzf}gJ>AZG=8*%o_Mk0`8o~kt;a$2hhr?3Ncw@I1J-9yS*c(XJ7m8)yK`GP|C5e% zz#{q_>PavX76(x=94suzn_jhQ8(NqaVZ=JO^BxcH`Tl&|?cF255NpfV)%NxIyT2Ys zbgxKe1>?9~DxdJ>RR|#9Vq#)GGW-Q7VjnOu3u+_Sic@G7*<3oYL&mY_H(t_oJtMSy 
zUzjMJfmfbbM2YUl&1Sh9JDPJSZ-IucWCi9LU5;1xZIZDCQcr26x8XC7Nsn4lzT~sV zgHsL0kC6PNt)LtG#;fH{t3JMP^J%=B7kSQO0-&YAexk7n%rUSeVmoQ3&ZMASlqIPrixl)#9= z`_#ho`cRS}B(m@H3_tw&ho;f9?IQA(x2}sdMqoXmxx%xa-T__rtMyG|@~l;Q^&n%F z%u?MEB;{5Vc4q*w6ra96yL>U9v^WWyJajEKEt z#eI={X7XU(tt%1D6JS~!=w zd66q&Sz!Q4nexmZYUU)< zqw?X^UpoyiaxeK8LoIwIxpB|p#|IpL#fn0F!ur^1ch1LHlt{VVG$WShX&r97`n6)$ z!TRIiU!#a^^5YM3F?{cR4Z2_J-*=`^cXzQ3b{>rr}V`@}@*X82O&m zh?Kd>Y3|GY_B8tCa@IwzSzIN~#@V{9g~-DQ4;UQ|pZ1~!aVs7r_St+fhIkz@U8L4{ zrd^GTkc?%B1-*y=flF1;Pv^`(ShN6&HV;pt!o8wWJ^$@hvX2R?T<_cVbf|NlYpkd8 zB`-qr4^jW@hgU%s5Kw0e$$GoBwbe78HFb0EWxXO`&;=IL0Z7Q~9vm1?<%%7ywQ{6q zWFYMy6>P2{s+C+|uj;$5T`Uy8^T|Cg%Dp#ZVVI|p`HUWOg(U}nt3e`E?<8>1R6C23 zjq2ZbOdBHW1d2z`QtFKm1f3lmxIiq-A6!K7pU;EF5J-!4B)INucuXNup2$3M<4IvU&%Vodhg+l`yV1SkEj`9B#PsO2K{24+3`PJy(aiyCJ+;4Kr=Q z(YzteCxj7+4?R;C)j7n742PWz%oi2jSv2gGN1 z?mM5^^%1d`S+7!#I8m|uNXV8SZyxPPL!QYZS~nlwX#1BUl23oQAiaK?(g|dgl3C3nFYhNhz%(#?vEey!Mlhg9o(7L!W52#1 zs>vI%E?xYCuBP^h-+hN#>LXeI)ZQPtTnZEq6oPczP;HCG%jduaD9Aqozxa*?{2L|v zQy_Ub+%l{mm%4Gwh*50$A>`{^g?F#mh+F)=*B3dvfiqU7P{N3aR2k&ZWNR0;*cSlh zYc?0Y2abpUfN795- zPx5?EG_?8{WY(&m`_`%4zj{jeCAH2W6J5-;GBqKk3Te1zV@Qqls`Ct^I16Id(JFlD zE%A_D^Q>vQnk72q&kB5teLCL@k7n!EgxUC&NreA&b>QISh5IpZ0!^RD0+g&Ue}3> z*AY9XOB=o8;4b^KtEXRqRTv<;3GTBfbG(9^J}>noLGTTRX2)mngO=9N^i#XVJuS*N0arq6a zgH;BO)_7imBZLTMg7Q~=?v*hrm02pP+Q&$*=LsY78Q$&hU(%l0dcytUXagn7HhYS1 zSDPKpm+LG7uw-fkiN5O_iP65d_!YcqYHA8kSEY^Hd%2;L5u>4fQa7}yO%Rh>!TQT! z6MY2#{oqA~PxY)p^L^uwTluFhC6BeX-(E-T!w7I~b3d6W-u}Vs{oauAVl4FeWSVl@8wOOtXG3< zIzm;Yos6Rlfi$&+by>SdcQgb4_CToZdv?3f^Tf0%?e#okKxifFjYGf=zo|mO1KYwF zrJ1)y9aa|h_q(FBaku|sv++LP2>W4<+*1^6R5|{|ztn0$|HY3=zk=%ybk=3Ms%k6!`O9hh|MJ%Ubt1ITH&FGpAR9kvdbaj|A$}>)7^Gqv z@avN;-3bh_g9@KqXIs_&)y)3@rr7epZeE{Gy}bTkEr8NepE2{|`Ty-ZKwm~4X`io( zIsTIN{uA-eQXD+tstWgQcKLt&=)awq`Ue^&KywwlGDJdzF z^Qp$-;^GiTdTMH@YMGkhVzpthYFT3XcnlE2o~|}y08Si$&=EH994UX+e;~c@)2A*J zxKNAg9nJm1wKMTuOivubb2&1wc>w;XapKd_v>qllR!_YU2j>Fq8p8pVK9_ILs!S7< zbD-7Rcy8HgA9Y{Wn!o=lV(AR>c{iV!`{>mi_2@*R?O?MFKIgnmq<{Jk^7M_!(|mlp zOACmJLdabfwh(GG+MQw9)}K6vCtGty)Ydvy3;)EM0FdM#Y!T6N3yukrqig=jZ+_Bm z40czx`eG`3*c051u`Hp6E3TNMqtX)#>A|wpBcVaxYioR)DSr1RDLp(nh|D{A2=2@) z=wO=;d^)&>!JN|ESx03@asz3Hw)w}O423+x_PZSc=|8X&-wsl~lT#Kz|CNSxng zY(K2GoRfX|NVLkFPK+P8TDS<|~+~H|RD78XRp;eBlv6^~rwtj{x zCL=R&Q*O;5XNrWC4SM0Ca1OpWid#1I&)$v0Z~HeHoN0{ht#fr)5dGKnc8hrDv)K=J zxEEb84(-x8unN&xRE=AE#BdWWAju{-$f<)~==tSt*)A_!vcDs1)S}1bjN|uxxp!)8 z4M!nI^lywD(P>Y}!I&-b z29T~gXm4iFW8956(zT{8^2(AtQI;AjNq4rHuU3vwXJMQ{5wru|N8iOnWa>VL&`tP8wVdFxiqqqoRV6)Q9`ZstZ4021iDU1Cm=c-f!-JAw4c& zrMF$lVZG=JxG)*_hhpKde1m&^do%e0WcGysB3qvmAh!h&-hKjgk0B=jWEOT4j?3ED z+*A!l)#*VHg6+~h-%E_V9wEy8{-@0mx5x*VDjEBNoh&BJ0pD;V$~Ps2YdTl{6ZftJ zCt9A(iJJOsv)4D=ae5y=euKI&jFg(B&8a5FD)!5ea;9 zN*+&j)zMi*Q3RjO&Xn*Q&hPAw513~!ycH=eH#oK>ChLXNpVmt;6iL-yw1k4W8sN>| z)BW><-|VawMK-ID`PaC=_G#&;7}&iYuKTENhXMz|$^wyr0|OsYYOw8$CphfxFTk|b zJ4*$7xBJ7yp8gWq#$|f*6Ck7Uo00I61U1~*V^;piHwEuH$)o0pJ2nr`)KMFR ztlveM{1LjGp!h<cnvU^zF< zhofosE;=%BHAF@^wcMH1rBCSD;uIf@^<^Gi9SI7Fe0MNYbUviY@MtFU^7&-lJz;00 zPb*k7mDY5AoLZ7IXf!XY_qVf|d3yat!-&5Q=ntvlZM&vOe{*p)s*g_t7{_>KU~Lih zs;Bh6kJ2c;_0v9UciWfKMJ(uB#m@7YCx5BE1b>Wo8{dpt1m5FW5knzxc1Atw5zF%L zcJ70wy9fKZ&UVK6{ey;YJrB?``Z8hRMJ}MQuK6=VsXcaY(5K(5(=i!|p>@AFIZ7C|rL++pi=4~CBo4e(0^=Nw# zTsg00?*NrxWh#aTp`ySzr74J0XZhi0Km}xMl(pEc?FB+%e~ux3=??A4!D?h*JrJZg z(!C5&ke|OGG~D6!j;z53USFfz>>Oh=&b)AibtHsES9@TkYx4udW9=!uPxRcFrsQN< z$)3{Kj2g;53!Oe;&jJi1BwQqO;ceWeZ z7YZJ`fUa#%WkNNAH4JT|)`Q^NUf3RKgeGX3w8|qi)wu`YrTr4kN^x7OHlHGW1-)GO zcOPIkVLBWcZ}wyB_W>-ot7&OQsbQ?liu5-7m&LEMmTWo<9w(PsUsk4Fsze#vn*pwE7p?$A?TIVHqn1r97Yp7jgG~ 
zY7dU+iDsO!>h-3S4jd>y9}MoeQr6@R{97t-)%L^HiK;xqE7*mhs~cs7>~2}H2?kFh z)ckx?1o4oV{EnC=j{Ib_${&qUCMWHm=To8R5MmWQnBvB~r8? z^*L7;YS&>734kYXevi`yIKKBwF6lhGRJJ_+_{+3IqsMUzUxAIV^2*smdr$^U9#+6l z+@lWH9qM;Nd%hy5W=;_K2jnqI5yGQBh+yol!3Q{Hq|&oX_tiK$5A=(CG8a_xq0!;2 zPQtP0_(o0Sn9(vQ-zX@H?1#%@}~Adqy67t|q;2aahq>z27>x?GWdId=~&m?_)UfSl;{e>uic8 zn>gx=oX-8ZM1p63Z2O#6V(&~v>4>1Sd>dOvPP~jLy`VKD-=Q4q?26!-<%z`cu0c{{ z|Lf13&l!(~$Go_rHv_}F3oktmL4Zi_Jj@LXX^$umXomme#d=RUVscn>`S*3lE>2BI zUu6x*PgIoXLbmjN1+}o~4+|2nk$ab?9DyAP6?F-pA8kKt;EMkya#sga9B0QW4_HcZ zAB;jQ^bamEpdT3%|H(B35U~L*cEY4k=Qtj_#12%L4RC z^NXAS%?l&h@1P0^0pDnoCSPfdR147{;WhqbFdB=PX8MVkRp508yZ%8Nopi5zdR|jo zpFxqkXBi5aUlw|1f~zv0Qa3p%ARY4zLiA`_wNnXQt>dME@eS!AaNsGKK^lqT?|MA4 z?&4IL@N2GTSufmUpnE5!Ue~pIRHM7$2xoVLKh!_}Pp(vsrB95?M4EtyJob~lnWV5? zNOB`On~>>#aI4_Jm2EV!k6Q7ac8ohKtmEKvb^`CGBSwbNJQNh{pEqn#`Q3?`V?clI zDkP`%)3e0zpr$-Ecm6(%1IDAyv9}K;%$3)83g~K%R#e@$VJ}~H?(TA15d`a;7ssjohQOkCIUnY4rA_fvkEOF~1cuy{Ae!x0Xq^%0B_Sk!89zuz^nu z$9UJ~pM0@@hhynTCJ4MdyX~$|b@Xav@KacG=6WMt}^7i z3^c9CB)~Sl{A!W{XiFLAW0F8(1FVsHO<;7o{@%56K52x~k)KaM5)%3r0&uVA{z#_* zrmlX5qT{ytr%<<3uA!NW3sAL2O@#n#TY=<|+>+6Q4$n&?8J43@j`MhmwQXe%iok<- z`vuLP56t<5U87CvjurH_HlMvFCt>@h7+-d_^=&s7)gkRH-mDcD_?kF8c_J~CUdXcU zgy`xk@GCE#A-1HcLVjYSUbCVbT3k?KPfZGg!|H`Ii32y-c$0ZHP^l)7P#};=g&nA+ zsXBXw2ogAxGo_X99j^uFO^sQui}6pemC%ST^-+7w%I!ZzC<~!+q*V2)2tlE`h7C%I zpq;)S7vH|_2-cte*gOR*<~9#zM|4U51ooJ5{TxG|$ClnQf#Eg}{${Xn&Gn5yFZgkT zJBI$nGKS$zTHuu?7DS0b!hLOi<)`|-q$3EwpU*Is6P!$C3Z z2Mf2DL;&UlZ-JpK${Un@XB|at!d4GTpzP02sy)WVcJ#knT}V}Z5}zG|Iq|+)jx-dn zl#7;jhkKDOtp{B!fj7K$I`FDUUTfy2$ac6}6Kb4@<0m&CbRobwr8Wre*znmqoCs*V zZ7dm#fv?#a{Smi*v9IAg7q!H`FYrtZtPK$oJaJZ1ikSG?uVrko_a$4U9HrNB&%WIO zCAeU*s$s+;D!|!TLVe% zgnr}VbQ&Xi9&yC2nK4$1a{HC{zCCD+A?rPX>ibe@@6Vs->i|VIpN6yg#>PQFqTMOW zB;ZQ;S<;T|uA5f9a`Sku>~SYb;BrcW*pCn#583pRvgAEO`Q#VgJN@s42MdnWWZ1;p zcw`F>c`}C+iQYL>9!kBh6*Fd6G`LR3mqC`6!7M)KuCL`&zjG{TdUZsr_b ztWwwbJJZV?0oB|OSkk#pds)Kz~kCn%pAR9H$K`6nX| zuXis#AmKnhFD49?^YXic5^g}S&pNB;i|TLzS3kP4Gt=Ps{qOPDLZ?G2^4jGcRrm2SFJl;jblr!1hg0J&Y9S$w2>(F^GWfn!3Au}Dls1%?`>LMl zyqKKUDy__(>g6d<2i!+C0t@hRk~{9{`!-!RUYM_G-H6?Ut)I3ZKjrB8Y;E#?qGu<|Ni`08}$&|j*7)G!a+khms+n+N9SJI)P#_dQnvA`hD8!@=yiG5ywV)TB! 
z(bwar5}5HcSI`g6e|r85x6RvdA@Dq*Z@y&9wCY(mEma~XI}YqeFn9xx|Aw^JI!mAS z$#ya&f(*%qw@e6--|AN<^GcP71k@9g3)2^A>Qv)|IeH63e7bVKb*L2zHC%MwT%fz& zdZV+ig9=LHlLj4^%oe{ubv&~@VZ-&E&1E4td6$s#+&G%A>Ji{kxwt=AnhUY(k0ef* zWc>JpyC(`T%DEc#_Q;iiyM|LjXRh2cGR|TpaK~+l7y)r;!k}x+vntT22a&mr-?0k42!yZ7O)o4TCmihbHN36Sb3~tLw zz-WipVaO!{=UwqZWjM4mv{t$(`u19JmMqPKMq{}-2G56qPSPayXS+)^=ZX92ftbJP zj<+uo7o<$Pv&`-dz&5))<5VL!mFODrV~XbRq4`FogQ3{)z@Y}H@)>V2&qe};0Qdrd z=bfM30JhoAHbRmw^6rfhsbxJuWPDlBKOp0V2#L6_4tUefonms&mHxQcUVd{86~{j$ z6rE$Kdz_EUw@nxER*24JyV(*0{}+1w!AeL8Dt0&~M|M(StG0X6OPvb9u{99}LJ6Li8dG3a z#N|mC-F=)$jw&!_b6BsX@S||!>mD_>hV>hnRt_4EIV2Y!QBtm8c;6s5)wElUp`>4y zxUw*ncIuo5?ss41;YZsZ(r(eYIB&>34&Z1snO-AW`62iT*6Kbd0E=|ZTT#F8*X6)h zC>$Pu5M5J~@cl22UU+c4c??7BaJbNBCIHOi4~-02j4$SvKhCBbq7--{HQ(uhZA}c- zJGiPPBL@Q>hqp&me$v{|zPEKrOW<`A2_yP!$|VEcT3d|MS-H+ajudQ@Q(*u%_u^B^ zKsU6T6vf4G#8g7Gl4uLFXD5O{pVJHM~61LaggsH9nHPuD}tD6srDY_Q`8kN-_^^78Hu>nRYV z$>rKb)`U{>R7G~=VY?)*S=xL2T9m|#Kug6q>#+;Hi|Uvem-{wlU)rlD8fJeoX)8?e zu6C%Tq2$3n-Vw}S{cCwua0P;rl;WDn0(=#X2Tt-Vk*^izE}Ia_=0lwC>nQo^<{IR2 zJr)9l$se@`EgFp{&`oV@2;~QeuR?0@IJN)IiIb`;FgYEr?}?#!X_aht{x#5cW&N6> z_whcnnv(IgJk*wOQolhl(L6<>VWgSk9s%p@_3>t?`4PL^v+g+B8b%@B3ARla{{?ol zVpOs{XcXEU+dpmtLDpYbX=|d@4K*Oe7+T&qvks3Cr^q*ffJ0b(tquH3{0Y{B7bv9ojYPF2)i=*>5f#dKGs?5fm9G1kt z$EzR&Cj574znqI0Zk6iVaeOq;mY_1~(S}!j<^GXq9ba(orrjCeq_OPHsBb;yay7~S z<#<};+Y6gzX!tIPePrO{ajT0?kv|4b_rS2w$uX`aPlz+b<|lN?y}f}cKRdfxpW6vP3JS{B(8S6LKC#a$v!3e~v`iX1l8zZ* zFD^`|fXVO@r4q(@`O%A@O^T0es|Wd$*MXH9-*o0!zzo0X@)PHM!Mt8d#*0%LK1`ww z%xwmA!$N0+ag4?nbCzd>oY4FlC(*vbhoNpwN;bIV5dggaJX=3)E7 zs9t5Hqi4lSVL(Ab^^vtZq{;ATc|T;;q{r)?qQmx90NSV^JF4ZJIL}^51o-^)gUSG{ z(l%`*M>C-+$}+<&1xk5NgUw7*vPiV-m8NkijeQZGSH3E8rp@b$E;i*T%FN#fopD6J z1gl)jd-!IzHw)bvu_leXB(P>?Ett>0x1m)Y^m%)9c zB?0#hx2?cDG?z=|eVt^OFr`tYE{_Nm>mD`7zddG(^0w;HSC@QhYBzpNL--n{s!1`$ zyc2s}9h(5VCUP61SA;HjbcS3YvB89#yZY4)@GYkxc!$4w5!g4g=w8(fZBM1Ww!n{z z=}5aSWr}mVZkIk-I`*JlPrF?$xTsEseL%asiCV4NAnl~UtGhBf2Oo3ebC+;?n?JL+ zp5eYPEIObbR(n`MBJBTjm|*&8#j}|WTz0AyNmmsf(&AKITQB8}%*q$=bC$@e*eKq1 zSFr2!yYY!H#c50!P)1dQxISpJc%2K17x-<%<{?&10prse?;Ify-}fx1t3L2<(fJ+j z`Yl6W#C^}CmdGojGgIl0y1wkREj-JGk2?hv#S%PELHJXtJ81p+9uoA-``Kf&71d3H z{B2p^W~c1-=!!Xe;;cWfBpG^`K3q%xwWKjz&*J0L!h14AEMM(XGp37W0X+5Naog^d zTh!?9#P*AHH7T+NP#;+?HnlP?7Ob3)YHB}akRipFmdA03lGhoEA|3VL*ON}~yX@r# z6JbH1PEyYs6ap=w{c)wjm7#J&mbkejl@q{aq!`c+Ohj@pg4UXwn9#d0HqTvnj&bYE z=Z#*L@_z*aj{*_iX>Yx-@Q&ec0##z*^$ufR%_MP%Vi#05-hX;h@dlGl+vDqg%f_s+ zy_kB3TN;yDbE>G_Dl2#Zn-ar_mS-kZ&jUyE|L&y{(TQKY!vq1?^W=sdRF} z?3bxmjaxlKm(f|72a%o~3@0(`(`fk+)jySX$lpY&72kZU9=V-!XI!)z)gUMN2sN3? 
zhf$ypVU5_V%RJUmKcE&Q^tC1F^gGy>#O7F~npA`=cL4~i0Z%Xd2d1%wq+#p5H0YMm zKbVFC$0Mr4yBK{#7<1B&f;R9`P=oP^iq}iz)B=TzTT-gfinP>vkXhmRZa6T!HB}60v(j z@>rS0OH7^W4fOxX|LkC{mXYI&a$sirL<)tF-p68jFP*a+CE`du+*f}u#$Dr)fK&1_ zen*J;(@o_n+Gq%I5pN z1Qr<_3>4%5Q#Qk6XXQtZLWwDHCK8;#*F!-9ZTK1p{@bA#i-8a&NJ(izdazIJ-_HN{ z%ly)Z!G@nc#_Idn++R(kBuS*#|F>5D?P?ig-*4WA$F_Jux&QGR%rRfoS1d2LzTpr5 zr#=C#b$x;PwP}mz^MBa}os3oG>%5}yL zLBMwn?@yQToL*mFtDR2N)_oE_t5*LwQu>>FGoQ+R1eBVeaPkcvZzcrTmYsm&_>0bW zezi)SkhhlyGW`xu98Nouw)S=|WW1<>f_N%KI2mncvP*6biet&ZzCnvvbfrLSv zJY1M1q%foi7z6gU_V=fn<8dX6bt|;cv|FR69aH@J%J&Js?goi{b{%UZV@PqBji9vZ zEluF-7M2T(iVQ(@Th;WsM0u266^hB&Kp=w?pNfd^g32VzLxDk zP1e?}ZY>V>IDy0#so2zv8J$jQgBp+WOzqd67VkM>9d7(^;efF0pb4BW&Q@mJg7rkx z)0ue-!C1vA>v+-~?)nu_i%mP2-Y={;Hv&E=<9D=%do6;GHUc4aw+A>p`E1V*3VgLQ zOQ3D67yH-Q@29U`<*S2@(ho+h?`|3zyZuc$6;o9n#`m7G5))gruEZ+LC@(K3jm(}t z=PuryKTf$rHL?24jCr?~S}IqQxL5!c1%Xaz?c1SrR>#eJen4#DJi%qJ1$}-A!DUNw z$Vmf>CV1GGJ6n0>dA}vUE$g=KhYUU+alxM@qv=&eJg@yGD1E`0+<-0CQmWL3AK--gcW>_LNX@@`#1OgZ+sM;_X)0X9k|{LY04ss^?rgyK^!}PDsS2M&%r6ttNi#XWp^=O zfzNKcPUoo9V6Cas>KuV^!l>6K2n22Rn_b`xk4ZGjFxZTGqI0FH&C$9wMq<_*9lAxE zS%AEnND%TBBO%`4@Nh^JP}Jb)`Eu6idV7d=zR?+##HjD!=LsmDF9N2J1CfMbfJsa) zfSu}?C&}6DkEQ?*3^XzkZ!nwH{AN*>H$;RzAO)gFK|ui>ghCASK@xyqcl1R0F1CbX@_(#%t5@HP-zseAsM=zNB@mp7_5AJOoD z$%wizpjht4i5fS<4SvAKM4TaAb2bl+t3rhW2DyFWRsSv*NDW)-8{9s z2yBO}zKwffU#^WnKHeOl2pPY6N#7AjeUW9~h`9qI z${Ke5m?#SwD|>vU87|=*AK9nRfGFD*Q=QP2Mp*5{k1Qef4xrJG3XnZgl99S%8&RkG zXcjoh2vHeCa9WTl&Fj$YdsZdSXb z@_QYMA@iir(Y<8UX<|AQIRak(1As`bVxBlshkV)9ZH_1e8-+#I#o^;Zv!j{yYLnu} zSBp~sYVREw2y7%#hZxSt{x%$|gC?{SJMG6GKc2x81h8QoauT74(N6(g!)7BvZC@8B zC$SQx0#}}Ai9|@i0JY!aqJ1#R>2xuMmbGqW7)fTR?Rz;n!FVcL*v4y~xP@hctZ#zF zOk;%Vf|tk$tz%JRM7>-FcQ#5vem;5SBPoTvbLS-DR)5}ij>(aB^2e-iBVg#G(i$rx zWAt>_dG?NDSfgQPQN45Rka$gB4(`nRL8l>k)^n5o6$+RV748v_mK_aX8`(Mdp-C!U z*b2SILC#nF%g(30FBVgU2*QvLDXNzTgtb1l_*O*h`Vz9Qp;ejVX)L1L6&%04Ry4}y z9H6|Yv1kn=e0A|r-^mlTlaui;R|bx|F!gS8;z#3kd6LBj9ZeN)DEF;4Y$6^*JIEg| zjQ-Kd70jI-Vt#KzT^`O{C&z^Bnr~x+4KX;sZK#R0OU5LQ8=4L3@o#iqPs=vGMibMka-t;-`l49`HUWwmSJtRtd zgnl_2pi?VC%%Wlwsd%1y3_^~tW(d?xgSaGTCt^G1zC# zJH~MWmK`UVVOnt7=CbG$c+G)mXlXx)rG-Vj9-$zHH?$uHR5R#YpJ}f_FlRd^`2VMKNPfV;vd?P8kc3R6w~ZPcUEY+lsDygCWsCycX32u&0{Q0#p5i##e-|hLKt3$x3gz{%}hV73KpWO`+d@c0IB&( zYvs*PU9m<-O2r_HBW9glO*ef{y!<%X;l+uf`~_LRn?w5EXlSt0a9M&-bCh;&JCPe@ zf5_!%1sv8H`qL*6?uE_C@|x>JQQj)snlTbyN&M@zP{rtQwqrG|Rj+`GDCM(NlI=iJ zoOjn|jh8~Ufw!wGyO@zMdZ4>ccrR#BKW(rmGDBi$YlaisJ>-GtiUdA^Kn3CZk8IDq z-=pwb%#5t}SxPl{J@(Gbv{+@_$Y(F&?Q}eD6vkuL4Bue!=f)EJ-(xAH1-*NSJuig{ zJ)@&VI%dK?(hHlLQy0pn6)Ijl1BNNArig%=0IfEpxys~niFPByP}F->C#VPTBDCA; zfoABlQr$86VKY-Ce^7?1#MMtq;sXMTt!_jLq?7jd$Loybnc;>^g!p1!{Hm%LiK8Op z5PaQpu_V8n__;UJJT*ql#XC;eXcrh1czmvXp=@=%ol}RH2dJkuej=5K`A{SX$s94I zbC$ji2?Tbo7iu<%K5xgGb@!(fC}L&x8+03yImTw$;)G_NV=^eVI()DbmHchvzT!IOV;c-sxZMv@AmhUa6pa z!jc$g(~{}m8){OKg24%CplO@kA1j}b$tXQ@>@$#j4eYRhWJ|xMrw-iP9AwMhcZTCSs3sm)5H{Y_Ii2Cmdt(!L zGU_QuyXGaM5MJGJ#iUaKqk@=lAokKWt?abfQ2O$H&!3TBFPw{*aT=$`XGO@yZL{+X zRUk7pmhwRY_F8kdS9<^4>Z{UcHD#tF%d#KkK(N_T{mn^dSc5eUL33-0VIrXr#N&Lr zS$G4O#?c8TSsE===@SDLcp|l?V+uLK&>(v=q1Pdw%-8dSiL8m{Pg$PuuM1`Bj4~PA zs%eZFFMzS?mKwyZ8W?uRh^Z_kU1S3#Iu9%#&a$kIHBE>?zZ@}&G?fC zaq|3rX=(T_>M^V8i2VlUd=m`(i!Te@`&m9|(AyqSu*PxAu?;#}1(YtnA~JTJPH&{@ z5T2dOSv1SB$su+^GI;ADPjIELkzq;H0Vo4w!f_u7lMA{$aK?XiygsRZM2Ite%Qrlh z4-Hs6495Zu)=aE1} zGhD&7>S$p`r;t-2^<7)5d&FaZgF&Y}x5*GYf?ZNid1>Cn^qV zkLA+l^PCHFAk~R*{oIAT+MWFntp!w@Yt$Vlzig> zYe4A+haj*{@w1xlL&T@#A9H^{ISOG%jz%7gV7-4|X}FrZ1KhCf1$-dngayP7iiV!Q zCCY=G65+QBuMHj>0Kd_scbvx~=}dK!8n$qmXVp#GO#t zQ=R`i04kSYU#X4(%D)Cf(VwpE;m;+&U#JzeLpic5B%^6-5x2OVPjLkFwJ4C_;CzFE 
zg50rc_kmkk3&EUS+Zdv{;bq)aV%9pc(&v76tjh0mm`_GZ+JyBdO+~N%`n3BOh-b0= zjN^y9yom)F+DH3$2L0iDFP5{Sw-My}WJX8QFwFFUpY_nc8^KoJAK0;Php~Ji?iIfV z`}BB%=-DtovlPJ~C|^?5Os%ewJPz5|&}XalmoEWxyj&3R@btS;ft7%N@2-4+p~zc{X zlEQhC4_v;~uf9pBSKOqNpNy|VPJuV!eQ#iVtyR-5)?`~`SJ zg|R6Jk{frHpPvC717>uS|DjhvQFB4D7hnB+Q@>C!`<96puY{&VCY#ZZWIszDxC{LP zQohg#o9caE8~wXWq;%5}c?aXq<(u$((B2I|^I+?KAKrnoQ7BILw9jZ|l`rCqsF3gp z!Kw>7Okk1X2|CT&@(@4Czeu7@S|nu2SPs8u=@m)vUy{EnEAU=F6^VM^U|vS?f3uMf z@w}2kok}{acUMzW%g%Nuua*`XIpnM$QmXDua!$l3tr?J=ubhnYp{O0hi1n!CxYQ+Fx_yu8W86+_}oe!B2V;(Kn_ zyL+Z@(%3SlRW!6CUJo8f1VJs4mAs|zJzT~bBsmZ=L)76%Hw)F--uxTF^@@4700YE>EFN zp+HEhW#l%+$g{Xe_l0r$ z`pyodIRXKzM2g@SC172oH*d3Xx5P$&K}qrK^f zRRR&2&IIKi&4j+ad7wEM@PD4@g?O3bn=i>hiRzOV_H1owXEw@3K?t@C;NdvGnn>Ju zd-P0Ya^(E1Mia%zuQ>|s;f|!A#80@n{QFbGdkt)@fYwly^4>%z?!;;&{AHI<9pf0y zh_dB;3p2(oes0nW6A@x{LK41~JFeMP_K)%w&U2pAJ6F@s&OhGOx#xFIWzm6^vl-3m zwl`Vk!2i4~$xVuA9|05_TB*|vaeXruLBqzDfY5imXkOn5teWoX({m@vdZ}RKebOPn z?2nMZzAOdH(e#kg;F53YmR7aT*^nhDckbTEn7#raLhQ5aG_!?t57S(N4n+RW{&Hzl zh1dE!8?f|D&qZ~WJ}bWJ>#4UC4IKxr2Y(u)(Onpf-ibX4*S8DYj#N9f5X+3c=acbw z?@%9MFMWuw*o^(q`TQy^_w^fm6Q%~%4C1-i9me*K>lc^1=i?oVELKg_=*{aw|J;(O zw}u)^V~&T8uF}-1_dPzD2h8^Q=~9gkZgMFOy6tNt*ZYjU>LY8DUpvUN%gzn=(_W6T zzTJ*|MxY#0a8FF`!R$V&X3!M1DPt6QqsZ0k&O5BSI>U~s_X6f`|_k95D5Q86LIBfNP1q5szPUX3RUb#{^Y$E}S zlVM#7a&iM89&F=p_h|#zOAbEF%*<`!$6st#aUvrl+jIp5!O$@<_V|w%s-ggc2vQ&k z35YlKwtGA%RO#`3v6vREH2x|cqVWMmw-(qw=jpZr$|p6<(7CH>AsTmU0@oqZ*{vyoM7Zh5&X>w`(G5R5qkf_gpUmMaB0B!FzWMKnf7?9otM8`TD}+ zfRgA`>tNl3^A#k(iQ%4XC>x|qhO($$arZ37vONT3{2p)W%ZhfaZtf&6okYd^^W@dB z`So|U7CamOB^{3UC+q{MVy!nDXX||2 z$+v`I3ylv^3>}X6Y9i;vy2*~qP6b@3?(t(->`D0bcEWS1_bu7OgEve$39`y09O8FW zLLt9SW@lAa+dZ%oXw?(_%+$n%zkHYjf(P^Ya$LYJiy1JGpwp=KZH&VN%a@E+s?e%` zOqvPBqHVflB<%Eg=Lc%Z#$0~>Jd&nD18gSDXN%Etb93#047v4kT}V_?e0<-Z(2w-+ zW@cum%k>Dr&Zs+{M%h)@QUb#3Xu80DcL=Q#s=B@&IzqQdjxe6C*&@*EWXx;qa2Z|UXFgT1P zLNoxw6l50{7f}g`ozWE5#TF-8bt4m4BAQ9+CuoVGv6dF~E`&nYwB}8hriMcjVjmo| zKtn!N_}@;g4B6jkcnEmv5W)^3gQREXsr7O(>FDU0-KRaao0qeznjk&=Q+GP#quUL;)Pl(pD%ZZ#kGu2 zyWyh%-&?zE9FFTa>YW~xe)Gj=Np04}nH4?5l(oRHzDf;t2ev^>mrDKKm;>GL!NzE7 zME5tka{*Y*XJCzt*P`p{>7*<{)Fbm zfL@`ULF_QP+~00pIOy(GIY)1-%hC8{g6QM&f-%Oz;6*DqyVBGngm?&;J*} zpz}==3kfd^@`8={=HvK8lP@ulivNGed&{V}qHS9jcMBR^g1ZDKK#&kDxVyV+a3{ev zxCIUFP-t+s;O_2DVf9wd-e;e8-g*1B+wS}Me$rY&6}9G?Ys~TW-Us^9xLHI1C;i0I znXn@=Z_I=^aAM1?$vRdGiFQ*G9(T;aw^VEm8t(uiNI3J~m-nt>K(5~k+QD`cDo8bc z5gRe*j*=)|l(YJKDm&6FrXq5Gz}BxgcJ5oeRB3WFp4fM^io1N$ytX#$NGV?VtdHe0 zuTZ(6-@x?r;TZ;1eM}ZR(T(ZjnNdw!ThtHT8^Tl8V|>g^Y_8Ik-!WZCb zM3LC?*InxJ3W1ix9YLl{)$GJ-^)kDh-Ish!mzZwtVw5n~S0s6z7K$GXvj9g6Q}g!? z*OvQUn6+}!4yMEgyWB=V09rd)p~Gvt)~1yKYpU*kw!#mjl0~A%Q)?#XwwnyZ%r9A! zQ$U*ll!z!ugaQ3CeW%wgqd`q>vZaWKNTm;jBpfn#eYZjcLg_Woy{CHr9=A{?c`FAN z5#z$hayOfG2557Eek7?kMdC70<#g{1C4v&f1RS3qoLK4Og97?pot?XZJ{fSi0otp~ zo!=zWEkmJRr9XzpB_4+MRldje{h$Uqz-|}oECkNURh*Q)0n97803dGg^E=9na+8A& z+G9MIf?*8q%}$ewu#NN17Z1Lfl5pU-b`Abib}>TmrLWqlZgbV9q`#VAf~03Qvbfr7 zS-QhmukMSut!CZZ)yL@G#Y8F3C-6w)C0M3xITWP~4)Mrz@ZiwVc1G+hwj9~@)+`2G zN?i`;)q0L1mkGu(mph0QI+@6nmex#6H5>Tj7QTR~>|Ib;X#Dv;OzyB|fH_FSsffB9 z;Cz~!#Bj>4Q$D0Kv){Fb;LLO`=H1w4%Cs<;=Hac7UYX_)ob#SncG-E>0~#{K!T0`? zF++nXY-|xac-=oCr~5(vgssMZ8(O7H%e9gLuBxtS5vOvwG?4)T`K9|@$PK2u;$_R~ zlBk2lS7SG1meW){UX+_!wZM(=34lr=$%f4OAo5NKAMWYX%kKwnXiZvq;0px`{O{j? 
zHRC_PYa~qhu$&9Q1N~Rv!%gfYT!~|FGnx!|`J;FF3ochF%|r86u5`)XNe$hnoTh?T zCFI_+!yfSxR6JKLH{s$jrkKq+S{BZh82DJWQ3SQcWu?JACyz1jG2mlFD_qxl>^JK&A`5{HO4tuBq&pV$wG&ZKEoN zfBx}-T)_2PRm?;NuL3FC#4%&OnwvACW}_`3??^Eon8$9ZZBG9Qs-iv-*FyP`Do6_@ofE^%oP2 z6CfR}M3Iv928Xpa>U(Ug6d;Efg-*Pm%8SJ0>Jsi%$057$#lC`d2mhS*4&LRcG~kac zD6p{HU4->;fJ8SBOgS&^UehTCr=821AAN~leFLEmt2p`eyT6e+yIBjR~-O z1;L}b9L9Njf5?dymOX zn!s;hG&+}kmz`I=ZncaL(mg2N{_&%f{6;deB*a^LH}>l@S{Y9X?nz%H-XB4O+gr(( zjC5d4R>x6=%5|+vC}Yog9Jb)IX6B&2xXMF1@V*Z?eJ|NmT(V7RmhHqq;&%?K8dfPb z{nBZp0<}gwDQyykr-R{5Px7Bx0Qk8aT=H{YLL+gv#6(D&s#PoG@fkBIf9GIy1k~jb za2WrAiY{t2vt4j`U~sob+n8AU*%xt;HL|auZ>|o8DM=0I0$ zNQ2h$Fc+@S&RlFbcJ!(=P{5Rw0L05vwPvz$Ln0MBDpYwtC$PTe{n?K zkiq9f1&P~*>Vu7wgZz>KK$EAuGFj?72IMmSi-Tieq*cQ?x$j#v78Rm7Kq%liAI7TJ z^1H7uw6*$e1WlP2cdCQ0S-AE>lM8GWk*gz}PBL3iSY^dd5$3y|=JO9}5PyC%b_92) z-X*;2ulx+q5y^^a$J5YdTZ|eRArExAC8y)+uPH2<-EyxwT%G&k&N)eoH&2+FeBNk~ zIJdx#j`u`FpnPDkiWqxynsN=o*?N}`)Ar;~d4k{?Q#Bc-h(X{C@v&to!{&m)Eo7_)10d%ni)UW=NxaCleQ zd^<)(ck=M`v>P*CQSh?I{p|-cSfcv0k>_kSSvBxo8TtKhixcZt{)uwbnM8MkBeVX! zmvj~tNI*&A|arba(@>d|Nudl!2 zlet9}r)c;22uQAKv|a3ec^R5-b1*O;PXBJV0stg-?>pNw54H;9{=PmVU=?0mTx|Hg z?AVLSpfXJH$KydUyaQ(aAw&>T&?xi88A_02Kd>+}d&Yz_1Yy}vJ z$%KgDX;K(>-4mT4K#+j!S>=c+p)>(E#_=4HiT&^nfH7$SSh8k)B4t9Ov!b)RY*{pV zHjrRcp{7I)`MDeJyPD1y!$0MP3vKqum~ttxV>R{lNl}z)M#_e12HXQ-{_hb-=E*;T zP%u)*d=Q3Cv}6h#OM2kmJ;rrYlywxi7uhBU&ElekHL4NKJ>AJL46R&a?d#`l@BVJWcZp^$8WhW=>~pt zwb-{v&*!}f-y9;d6nx_mztf;?rhuP*rPMIez$dqd@DLx=NS~Et(b9_}j+ke7RebF~ zm*tK8&^GG8T|XhM$<^wQ)+c@Ii~?LH`E)^)?IO&cT^D~7_(+3fwc)=T)^1{FM?*sc z34adNT6jo^V~($$nF!>WDl00(uG+a5;7#cCJ^wwF1*QwS*NN=Qf=g z90UHRXd)VdM34Xoaq@wBAraefwZ(0v*PvA5Bgznbpij-vxq%GsQ1+UpXilx!$khWB zCRvFb$uA6e4r=%iK$J1lyHlZDosQNUfB`%f=U+^q(1J_pk(pLUZV*r-Rwn^|HJO1p znx^f2D?!3Faj$nxdVlP2{?hU3o;s)%Q5QT$2xa?os?K|7bVLzW8^te&n zw1l9qo>G)~?|1^If_1JG`mPw6a>^eQmWWz`8_I;^?a~+NzXucTJYDJLI$lk=w^KjH z3iooGm{YmFw$G=hjSU!$#dfSz8KmS^+KiTbv-IKg#Mkz)jrnN8=_iYb5K{eW-!ol9 z{#d=?mxSOI(^tqvI4s}=`}-Mhx;f`==(o!i)){H~&w=0L&nhjMEHX<;qP@ZNAfe~= z&#xp+H6zt~iKw@x&M@j9igPM{W(EP5CXuun*r@ch@JA}7Kis_J5^H>eKNX3m8P#QY zNCXL2MCrG)RApMI8a=oa_0Gcv+~1;$-JrLoZdjItkiaQUTYEAQ3x; z&pBRHD&E+}mb2MOD5iR%fsX|e@vqW`#ZC?@&0W5)L-T1r(^Q@V>g(V8W%}Iy!qoe| zf?DDWh&>&|X{T?(U?R{BD}g~~A>p2RhToO5@5a)Sg1F0`21rr?@fxm6IOcR&SGLPM zUW88Gc{VRXs~Vt5v6)3C_J4_Skkz#5KuFNu7U$h=Wb@VBr)=IZ zSewSBFJ+G|o$F&w?mUFFEi>fsyZQ51@y;>&*Z13{!EgB$@norffG<=W1+=Fi*4Dhe zy|fpPGiSVa{pem}2HHz5y03yHm{d8rUdd?iWV1>NkB`b#7Bcz_n#i!1V$g3G3X6pI zpL_=7Ai#Vnpwg!-cDQcRP`r7O;!t?wx>XbgAH~mVSA(r?xw{)?;J-B3G$%JW_XRgT zC+J@3r;smwJTM*Txy;f4;~gEQ@tNc0GqW#EwXmB$Q=tT{AK>l%Mxfs!~cE{ z_5?4W_^XIV<@$EVBarT`Q!SqBW_B>pOxe^&zztX%q3ijZ<~M{WE&DJlvI zC=3jYHd6R?9guUOKG#?6!$i5qp|+gos={2qKTg`wd|$7qH9p0OiJ`1WK|c(l0BCo8 zzw80`b6n7K{BH#(gwKbfMNhyJkOR6ThwOt_uEj()3C`)X)Q9-0JD>pyl7wDG18Sz7 zA23I5QDaQ#Oq&$?=Zf@nRqp9B0!k>#gdj8(2omUD48c$SzZY@^m1}6klhttHa(PtEljv% zdSrl^yiGY<$T*LOYH1BeAj94e(BF=g)Ot^XGz5v<LDj+H|jDr$+&L{EA~fu*)~y z$WkPk8#N=T)Me z3H>JfPt}>DjQcYt^;B{&ZsrVioL+9)#>Q0sq+}m*3#fM;*4h|Ix5~zjd~NQF+u3+W zsl`qHl#LEbek6)4b+W^#!^&0jb~Etx^_?E(1+k>M6dHK9mEUm%(vx%2se@Fw%xk-g z^eP=J181APQ(@E+w#Km{Ge#fin%>Izb-i)ssNm7jHuoVg1x--BCVmw`E?N4nUZ#?e z_#o5iML4FRy~#4n3`O1P&p1ye+a6unF>`FnTlGQz#}nku+{qSuf!~jX`l{C7!&ZP^ zj?e{j+S>mWCe$-mPP{VBoz$VJ_IdJekMB(d4_sqi^0D>}LVTIEbm zDl8<&(i~nfNb-g)cW)h-^%1jYE?E8{yTdn-+h(cHjfM$*44gksRB=bZXo>8VAIQO= z0E%zZy3i8=@mFl<>@~gjG^n&R{p2n5xE5g&501oVD%~ur(o^u|(?$lwfJ1jOo$fhW zr)9BLa{8_QCe=h-tPN9Sf@QuY1My2%_iWW2e(2Pdma=6p@F;uxAK%Z$mhGTA&r-4Y65)^Lak9EF4M@}?hX0St&e6!^(snxEYzK^ zD<|7ITx_v?Hve>E{10DCYCqJ_bN9^pdko2c`x^RmXi6aZ?&@12&i}hDrYpb-4XJDW 
z{j0?Px4)s8I;RV+fQ5y1y55t2czDP)>XjGA< zzqg{A=b*`#01AfYWgz?onBap*OiZjoF>D8j8rEOlDhnr&(Ow=N$m#v`?)@bB_lfMEl`dQ}<=dkpL9u5aowv9^`nOA8N-PaXt$s1^ zxoja3AG4l!lG&KV6aXJlmcXINP@QUS@WS7Tc83De;bFhkRq1xt9^L~e;<@s|3Fln- zxUT|yXSIuTzkp8v_xaPxM1}v)>iY{kfIpF(3R&mtS6n%3sk#cepL~b~^yn9s+Isd) zz7zb^)J220M1o8C>bIN^`|Fku(iwa%m}oHD&Kp5Nx>1r_o#ig9F? z@quFeFVO=E*HQ=JK>$C=O};96wRN|*9|Ks-c2Kfc$SbOG?dW7}hU7iMF)o|a2h|luUJ>P$Gc{;LTb02F2*V%jkeNj)#3`@Q^*b;0q zn)Uc=9|VZjyBqoprjy0Fl2Dw|x%vqds;hs8p7uh3W;kIm0(wwN3W2SiUA&0DAG2PI zDy^)Bpma5zWKa^oNBcT6)yp-zLj4aP0HYE}Eb#7M;7PY6HX@=M07o@e)3mI;1ATpY zHuIH$4MQAgj91&tvodSd!ewV?=eEqmz&!%(o^3z`uZt|YFK;scO9hcewzPx&>Ek6B z`qso7;uoyz6gx>@3tqp>y8(3D9{5M)Ml(mo2PnTa+|j4}nzFf)SmVBmuX+=eFHS0$)Sea|~ zjL?r}4>F(@v%aa%z%xtGpE#J6l=HEb7(be^a{6ogw`){Iica{7Dm9T}qk0-L-tcu?b7^T&{M}~?jcCr? zA94m3tcbD3Jq2pru(f%(Vy}QW;D7L_pX1=TtYx0R&p@xg&%iYA*~*jq3sh!JZXA{+ zvuFHBB=2w8XFZ4i1KJsz-<7;PZQv|Q*a2%W!@()p6 z?K(!9AU8^aEvu=~c5B}Ema2$AZ`N(3Yy0VEe8#(Ti`;!a*2(1`451t1#np!w6F$rW zRYY!OKbsq5oMs!U=!AqawzdQ7u)31aZuP9UtNq!-6m9}1!atY_J#z9R+myxF3xMas zl^92)0b4@j{8z|l&l@`z+&KX?_JV6E%cVc?HFCe)FK@WZACCoRjAfWlcg)YnFyjZS z()pLTorhlk383O~*=V%U$pJkyK8topjmdyaICo=O!*oe*gE&k-U?t>+^EWUsU`;Y; z_e!qPoRq20$syMej-v_+A;Q+SGV=%UtgfJp@6JdHCcum}8?5mNo~`}|f9eAG(^R<@ z*H_5H<+iRO6+CqG;m%HPGExCrZil=)z@9V#_7n|FQ~4`htg=dV!I=XN35RX|a(&XL z$*y{d2?^C6=lYcgUYA?&2A9Q8>E#-u@r;5qA&{5`qM%O5-TJ0tKH1qy^XNWtMK_xP zA6`%YSeQpuUKM)lss}vn_%kvW?TsYa9bI7^UF^jja0hK)Cj(prPaPRU)#9JXsdT&o zxl&&!bawfRb6H@z8Xas9|Rd9g%hUf-2>jNT* z+J_wJ5}Gn0?bSv`ZBgN5hoXDY41B^lROE>Sr^^jy)rKZBRFuP#15UiQ!6!1Jxz~j5+@)x5Ayw1 zK`XSO7avDnJwlMRM$GKt`!cLWP4hvVX@acT`^lmAL(Gp6`Z~it14rwVJf|pkTicN> zt7j9q6Wxh@oA{r7H{dK1+p zAL|=#5}9dMga0RQio|Ca4E3|NvHGQ4K){laN1wZ0L=7-8l7i2a^WXk*xxd|!$Gu?w zN|zKrQY_~kG_o*}t<`w`mr8kPPX|%D=ravlt$_Z~_gaE@tjga2O0}{zTk4nBO$Z|d z=uX&zsxDOYW&Q%Dy6b-rnEslWVyR3DSnF2yaLeH&2-dZNeAFn54Y_w-hfrL;bZJ*B zegv>!ws+I{wv{dnqZ&gBp3uP$iaOiU6_gzUA-`u&pO@SDTEaHhM$;?6NjeITT#G=j z)5u(;ppk}lWJA|5VY>`OEyZDSBJU>wF}&bXp^77mAC|j0A!#wBf>&FB;&uxNIP}>S zdH+7C%Kc9%9Syd%RfFGLxz@o7i0j^a55DS^#h`-(hwjS`c`a+MCse|rh#G^ESRo$G zZXiFf^`*ncb_gd$s6D+-g?CMI78>c*nWKBy{hf*l_`{n~Bl9$COErSQa~kd?p}vB= z7o{H&J!;j!RF(%zGWL(`oKa6}iC{>I3WA9tfmFF3hXLz+GNqZQN27M9?Iy1I)p|Pk z1Nonlmcx7UYR9{`uaw6y%JjHQd7LlB5$b(v_JYWldudfN-K8IUi5HX@=YGEkraG9B zXxASa^UTQfw3uwmoNaO1W4Smu+6I}AhW6?Hs!9BCJ;XK;>T+~Fp=V*(7h*S|?zT%w zG*=a}G3~K1?2m*&-WgNTIzh$3pvDS$-G_BFrG1I9+s~6sv_4oU=5WMd+t4zZifeqj9bh$N(! zFpq)m!x2$hNS+)xH`~vk7P*d&8MsVBo+LQo*fdbPuHePsh8(uxj6ya2E~f;TcW z`qEqK3weu#$B%s6uO^sk)5C{mOb$?lG*sCz_-58Nkn#l`#8aE{1(;B$-y<(91{h76 z-H!1M=W7eKW~le~bOzTQzG=vgw2e-+5lUr6Q6Qkd8d> zWmEzFb?B?H-K4PF*LEGv4ogOas0XxNCN*v*M_|J!>@SfvaG6i-WXtI~XX@Z#nln0Y z`kfv<0D@f98pLk&K$F>am@j1CAM8YZ&2j5vJB~HF`$v8+_xwVn3EFIUd2;~NIk@WkKQR;QJF}2?k?5l2*jIp?)ky>9R}kY8VO(dxu$Zhs z_AC#b;TTGA-Y!u5nco8`*1z<>bjirnNV`ZGt=t$j`9s_xkaFh?yI&o7JVh%{43_T@1yPXXHcT{cGk(o)MCTxK*Us~ zjZqDU!ApP%frFhU+5IP#7Gzf9RzokkO>Lj*Ucuoe42$Dg? 
zU~4-QVIwgHO|v)$cA(ywQrb+t9eT1{TOw{pjOG=zrZH-$VG8&9B0p7ejx?z&@CMgV zy$#l@Wed<>@A>VAyQDz+5YJn!ZWJaa8MfP+t-Lne@$dWlpkv$^2i+Kxn27iL$o2uC4{tf9<(r*1`wPx(MWVURB&&8L>U-p>r8iR#rWefh zXA303{cW#0=|{K>+Hd7k@jklFm@FmZv_l&TM?#AVAJV~23=OObzL$&*%mNp?^6Mm- za==iLNeu07LOxwtR7LBB!Ds81&?VwjAMwA~;_c#=hpmWcK~FGfo(|V)+MkT!`-&VO zq+5J^QZR*Q@koDk0JIbJoL+~2XwLQ(m=#^Uy(&n3|NQw=?S87wdI7l_VG~O513B;J zR_UqBB}n^*V~{p)DCl7Sp2bO9O2iT`yTcW_0GzykkwrMnUHo@pWPIG)p*$d$ZUc_83_n6!*`oOEcducH)2PoPTrkT&RFE_JGz8H`Iie#ip8W=>z?=aXPMywU zH{6;JqZn%kbNcm2Wx5!iM9_* z=eBW{g}xO4lqq4!j-H;L7aAQ(*|Om;fjR9a_QENg?YgWO&r^+X{o-GUX+vLs&KxjtPj!YL-DVyh zvR!QMxXE!`B#qS&SPqv%cR<7%A}ohSeTTwp(FTS{gGh%P+wuJl z0w&hyk1(>lsiT-FM(2&O&DDyf#y9GK1pRa2CzUr~`fL-bP{j77iz&0|)k@$2|BsS? zinnSeZtu8k-*VnhUydC!DfqEnfH}p(gScD&V^M^ec#xmdo`b^P)4mwnrKp){>FKH^ z)cN94Bi!?sm%u^9SoWuUcwM8o@g(IjZII#f4W=h4D!#4Dc!(r(vpoz%Z{?SUj_cWK zOv)6Oo&ejS>T$IuT=2kcSJM9tGEEn5l{fJtF$Dr>bWT(&n&hug5LmQ-&bJ+kS&hUX zIIp&4^VRsYqa*-!7m^bnU$1c59mZ3P1FfW5O7i)?J`+vSUkRt(%c8(UcWVePogNrT z9`220R9VO2>(-r@-|dEJ6Qc0ZXT`pcQ4O_mNezXnSSouuvWcacEoQ0kQOViUJ)9?F zEAzd*R+`8lF7sO78@33QZ`GY|hAE36G&wV>4WjJJ@@hlX6S8p6)K~gE#KtGaTBPPB zI}&vs1dL_r_2)dN>sp(X(6kC|qStBAxS#zCl|%qaBmIT*Sf=yNQ|k;Or8(=A!IVS8 zb9&5LeN?!X>Yb|obWK3t3u|_Csw5W=i{D(S)gF0HCG^hqr@r z7Z8U}`1yZ>%ofc=^)BY?aVG2)es=sX=#u&W-{{hF2#}y20W{a&adKkg9^|SNe;G`q z!(@|{mhS5CefsP6Xl`!)-P0q^HW-YExwSizV!7N{u@XMB0mNYc%~?7Z0XS2$$2r-7 z&!6e(zr#gmN|eM)QHbsrYE0U`w0qr77D(fOm;smb)BU-Wm>3KZuRR(NbEc1vg%!}- zD*@7fdAwC>uvTmTmyIa*7zCt&)Z5Hs@bK__qVR72j$IS>U!JN?f50Gh_xPRINdmAw zjp4nc82&eW2_?NMiHdxs2dolkn@y7&DC?qHwkYZ*^lXrc%s*Ha>Vku*T0fCGqF|i= zhUg%+*Fr&Q9z@w-;j4Haae%q1ZK_jzCg!&=nHpoA;2ew8cGE*EO2Nd-AQ^K&y9uN1&%taZ8#=dC!#UBS4;s(IMg~i!RUbOHJj`OS2=x zKI?yN&&?3+?~{7;JIjw$j_!&94&o!k)IDQ4jS)5>xy z!0IN#bR|af2P}_ISw#$RwsJND46uYWeQlN#t|m8);*t1$!bA)*>K(&;ZzG%u&B&gK zzpD(4zvG~M!)lMNSeJd&q-cAC1BVb&YavIKnK2_tTAbdtI@R8FLYU2b|G#H>nkzpe z_w@uUPwoD^BzaPcx<)P_B~&`q3actiGPaSP5Op5;z?HYNIdC-CG0G|`PWNYes+GEF zsKeBJI$BzJT6N|?YH{DiW>1Y_7tCL~4}gmv0Ll5r-d z7g-rtUic7a?HXfvg4e?B`}1`vVBRU%-jUt+koqT(SmWs7A?51&9xyBQJke3+^aE-v z`jY6nIxY<%fGP#%XsD{<0wH@KQG*s&D-wR^5GH5H9d&{Zj(G7eijVK1D(nXQh~8%iREyA_`tSj^njBblt5}ZTc)u; z`x5kS$6RlfX#-v*b5!?2P*=w_E>s7w;!y8KMxX$RI?ljfNIG(>Y zX6dXlLkW85!D$h8u0dE;Qii5FN%YM=TlIAl$H5=U{LWURsgf;$toca3PAK>*o!_9B z6nc#!I1XaKUssp_7T3v0a;rtz4ZmQWmT z;c(gD;j3E?>cqb`C>9_QRCl%ATI0_90Up0R{X70Cr(dLG$Oj39d!x3s%EV*E;Y#Y5 zGt{oFG>EvXD)SG!Ow9eF+pqk{oUrZvJ>QYfV#)e9)+*S`S6USDw=hqI!$3F$`Bt(7 zN!8}pGk;(b_oDay&y7&wQBnr9bVhcwQ}T4{kL)sLKyRX9>CWq%967Pl?mwMS-zGve zi&R@ZABogoZb_{qq4*2A z?i-+c>Iu#h5C0F>(!;b9HkI@RWG7J}j?F+Agn5dF{1#~1&tN@E=ZT7efkh6E2+hhO z1qRKG47~f8qyB8gjQuPwCe}<95aOj>1B?g$562Uff;yCI-dvTzL_+}pC<;Em0prpt zWRK##{R4fJn1qscTZkDvBV1MibofD)`apuPT^JhSReBgd6SGz_6Y?>TF14}0Kai|Z zFh~WD&G3QxIJUgnRIQS_Tv}RMU|&X^4~?PfL=opzaLyX(^>x`lTs+#V*=3tE$Dtmj z(`Qo^jhL5stv$@zRp-wVhU`44*dof^XIa_4xHI%L(bjO(R|^HQ&C`9h<^Fx93VQ;^ za%_^4M9Tc)o*E@okVq{p5Lk(;?jSg%(P=Yh^$pzyKVPRa-mBAP2cvfIF9tXQLM>DI zuIWZoCkehTdVCda{<7WpLg|jo2`U65BuHKC;@qbyP`efVFcdMQC;=(pM zjqNB1h9~qeOs>&2TltgxA&94`qkHoDxJ`^9l91&p$_Ir3D64w=Gf{WfP2KZq0k z(*0B5Wd=G0f~wko_)=(Ti?*~k2y~@CT{pB{5k<0uzTbV*8a1a2*7kOa z10dSD=W6JeFgLD;MF3A7ZR^m|N6LNaN`GLM-hWAimjLk<8_Iu?dG+@;fFYlreT7U? zhlG4vp!&8jPcf%yP0*|CeHx)H^P;a`NG~I6b7=n6YN{z|p$x8p?H7$ZUc;PhJcTI^ zHuj^)y2VOHT6~(cc^D*%k1i*zHS^T3^FAB&b-a!QWwGuX3hwa4`P!Y^V=axGS5PO) zmlIPha(|>tEI+^uRT-PYz0VFds5XS|qJq+%1_G~~)2gf4d1}HZW8tQqQjgXQ6`I!tX&nq8)6uQAJG#xv_kYEs7sJFcw%XrI7SeZV-Q^I}e>WbYqP2;V zo7$#PHXDb%MTR6B^qvX6!Gy`ZJ_IjC#Ad2Wwc^Ludj=*$XFL`o?)xasNQKVqAJssp zXJn|K?mCOnOTq(tc3b4rp+TkmQ?M70;k1|KNqqJ<|JRf-lAlJt>*~wxOry8551zG? 
zv5C>JBU2z$tMk+LP6O&@1UX1ayEdZW!vwtSq-- zsHPPw9C$tlS>P+!_$`4V*JU922YZh9uw!p{w_8l)?69xQgP!M=Bl&jvWx^)*X{Az1 zR-hI_y&4_02WTH1wCWPq4p;kg!IH-oN-Rd{wwcbm?q+%9`&l78<~cr~5!kK2wF-{osahL4*KZO6jm7M84N!&Pebk!m%T<4|#wAD; zo$??nvD031!Y%)86yYHIHrYq2%lZtgRIRxrfHId^$**;Qm`&`z`ZN6UyJ}lO6OE^q z%0)Ql3!t2_>W?BUC2Z@K*z69q0$i#>lpDv#vOtJwAzQ~oP@NHLCDJ^k;){#=EQ1#U#2 zvv!ZQb-5;7LLiJ;sDm$Z=T*JQH6tr&qeb1r02S|S$+TvL@BO&^U3)kfVzK?Gl;i$d z?=vOGZLrASY?FAA5m}5NzkJDn)Ib{A=+|}P)VqQFszD=Mht2cx%*lUi=61lA5EvL3ZnPzlOeAE_)#upn+-U|IuR@QfqCjRODM>+<=_)eOSBik#=Ris&ifbwY#scbJ5jWvGu*1?Mp3`{Z^xv zkRG5o04Fn1!d;N+x=+lp2eGRcE2S?-_kGgGU1$JP&>$HVS?OdYuVm?XXdPoV1v&)_ z=esvNPX1gR0llYv%9JmWji!O9^CfnrYFrUcD4VneWyRb_?u$&uEkgA;ugDO3Xiot&q!p z)fgj-f4T;?HR)C{anH6e`-k`-fg6rV3?Tx)2>Ua#n_w*JALW?E9EtnWEqK>_>{)*X zDu^|#HZw10yf36w>pdBE-=jY|GP3A`a@u^KEh8=nLl)~2*f$lGkUl+#4*JVRy(A^o zuao4JZu}f)y_=@8hA->xDqFEeFZ*TDOGaXe%tYG}7L(Rh$wBb$%O8bE^99(SOZ~q+ zh$m}j!x#KLyJDXXU)G{~J>}~@oX->q*j)Q-f(7N%>eAqyJ(@Cg{nu)JlXNWk*S&ca z?G*ei&SolcYy44?FPT2ud5z_NFd~?-h5Ha}7^3j%W5w#85>n*)Fi>W>q`4`mJM6cl z>{`xU&PZ5gC}pF$p7azkR0_`Nly#x?$jhgznhHtFIPvmWc5{% zzSz|p?TM+E;>k;KiN;sK7f+~IIp}e{+m2sUV0T^}Jx18zc zpEakw{Wv3Vc1hrwX#4u*h=@42dF$3`9ccw?o{q zlF45W^jcggB~+w>%mO?s?tpbjfu!C^BKR|aZTs&!Due_-?h8DIId1uGqRdY$_M?jS&v6>a|)n`;Fz=G{ea&+s0ZFz zc{S|gOrwhbD=xjNswGL=f&L5&O$HV+naTR}h+lVzdUYv|Ah1B4ZHBPI0BHVg=REkewVp}a~@LYoY!CJ_n7AaggYmw&b zMc=ruWrqNO8!8oK8T4&YB`Ahes3Nl{jOAFlSlAA%GS~{WdwrE#vV1@D{JBOMYOP%# zewz7WMvhA`CSZ1S*Mf0GsXE=uj}vxI39uzoZFSF9%B`*+i@vomYVBs}9t9SZ9GJ2` zN6;hY<+%&y$AQX^!N{m6Q)qHlXnNQ&0cJQX8nOXRS~eP*Q0}V>NVs4?A>8#d5D{Lx z5nHiEn>*7Z(c9sxd3HdUvE4(ku3>o*zk%%tzA_Vwg1$lJvMA?+^7kue4tRj(i{%8a z?2(gMKAVdjAzs6yB9kUGR2#)YR&A__H*JMs=rSU0l;b&v5&Q@qoL9ZYAC$ipn7-%g z^-lipeml$CScA*Cp6QWXE9S10N9NEwTZ4g*D(*-*zDP&lc9c3yk)Y|X1lRc;;==ev z-#AKXexFs?9E51@v2k;U+UvFsDnXRSRfD{thqF-DC)}`}$3=3XDR~0g6V$6=Fg-tt zN|=!k+{tNoMet!3Wxw&>oJE;K-T#6#wq)qBzrd*C)}cmO`L*JXd@+3?BY!qId}uvJM##M+6wT2;4#@7 zT6N&QITYD4re5v9kusA>&vuP`a|wlYfH4Q16MHU-X^hQJk;|PbH#>!*u1oIu5lg%P zsFW``$QKS}nKai6o+6Oh%Jrd7zup9gLg+jT^S6Ea~ zO5|~}9SkSp3-`>lcPua7e}{MGh+O)q^QS+5jja}=-FkKS`iLW40M@x0y@;J&lMXYq zXV5ap{Se{lh=6PR9Bm#p36NUN|8MAVX*)=ofdjbfTIgkTkIX6zhXb(1yxT>M*Gi>?!Qx-!QR9pEms~;BHmM#fby&Xywy;#YF6}sqXXHs2?;i ziFu+Lng@ZYgd{fi@rcvD%~ZdOiBDaRX1)t}oJBDU_w&D)Ry=m8QZ7o>tC$Q>DpXuRqKax7xUd+Y0!nYK*zm zdlURQ%!8KhX&(&AKECKjne4n)9iJ>Zn^!v5M4LpLTr6jQG%AOBN zbwk;P#FHTWH<&!DIb64}7>;2zi{u8coj zt>e;#_%SDIbvQ$zlW~u!r|Ob-%11~;HuTXMBpznPVPkKMo1I@5MI4l`?bjWDx1Ug% zbca{`vY=bN$^`d0`?Oeopk4zDaOvkzb&q|LfOz2VbzE#p1=KZ#9E*0^-&%@dX30W5 zsy*&u^vwspOW~&=3i^pxD)2*r0N-X=Dl4l3x5lfCyBtcnZO^j>C+l}A`(|Jvdj30C z)(CZ2JUpQ4kZ-B_^=snrrG}|jVo@Cb^%&(iwMs;Ox^&sufPHli%yu;yY z*EO!!Xi*Y92_lRZy^jddOVkNTl!)F1!7vCxM2iwcjp(9;AlfKFh(3B3o#?%cIj`?K z=j?Cq^X+{df6ba}dDpBp>ss$S&+obK`_?fGat}SYb84RZ%oOn^EOPz}<^CQC;%p<= zuzxjcSzMc5{p%~c!zFkx4hxO50?fl%?3+iKSfA-BaVpn)Aw#jntrqGAPn^w6x$2Eop zdL8EYN<3H*rQzY>P1(=9D%zu1-2|X@LHpDdZ9jyD_aTeGO7J6{nB&C2-M}p~L3xCV zZ1&xKQZYR1(*x?>B_cBZ9*7_2huEeXGBUq%|EGn=Of^)ky@b@x(#7Ia9fy6}$<@M0 zecj$vk={}XO(9o`Il5Og-r}mf$=w`XAo&rc&x-W}0m*cc*NeTOqPJZY)TVEv8Baos*dhEi}Dz0Z-0M=4KU#M+75 ztinDgb}$#w_^(_f^7LsPPrbPob27DCO@F+`XcFSbY~36in`U8za+=J^eukD6DtFQ< zN?J8U7F}MCQ+J_q{<3$JvqBhf3h8XQq=fy#t)3YrjwfoN=T0CY=S^r=nwHjfuRY_J zItw|$ZSU*KTH6R2nV(PDao+`W2}Y;QrjLnQcy;(&%pM4UEsVE!LcKBpJ8sf*Ig@QM z0MP8!{`l7$^wC!?-+zY<_LUR-pUq^;M&jESssl5IJsqF_^8UZ<7TJNd0OxEifkMuX zrIt2g40->4xOdnS*ZsTrHP8<8+QSKdcfiYmbLgW3kvUFLb>jc@S^jd&(KcYC=rFJK zpTKO#ZR(; z4Ic_9Nb*%;Z|ecKy;pGbD;b|d%3($O;i54$J3`>Zs|zr}R~yvpFIK&pVvdtH9&6d! 
zau1c5q<1Fp7S7NdRe=m2+vz$9FE6irFIX=e3=8c2mA~2!-9jJ`Jrf|eM8@wN#%q|Z zO3STZKqKR|8{`LKdNxZy5)l8KF2urBDIA84eZ(a*ElAv(QB79{GD>2}nZz5{;lEC} zp43z{i5UB2C>ABOz4;1@dcASuu4%qzivfSc2vn0fXO!?L&#e4V`e}cA>V()z*Ul^R z6b_vstR}B-1(WUzpE3O%S!dttm9uR{7-O9Bch@*tTPJg~2&+)@N)-7z2i|(c=I%XG za8pJOsNbgqh;AkW?H(`Nu%w?nXUkD|L$TXgydtv?JFhu@u~HZ@P34(k0qRcZt`}V? zq29o)VW}7ds)NnvWt~CEl=7rq%Z!g*fE`ugw;fYn&B2Mfc^;(6MAa_Brfw4o58Iqn z%W&MleOcGW$=;vv#A?-X|Gq6OS9qxSsm_MC?~015s^-He#v$JO9Pr(}o%%E4iq$Hi zxNU(X=coBcKPypj$H!qtxd}-Acz)f)i=!jo@dh#Zy7qM4@G6LQu%P0gVC8RXD#*6~ zh{%U1-4KJJJZ<0>rTzBq(GUXjPK0ZES#09(uRC*Yl?#5(%+Fks)7lrz$7!Er& z?3n>}XPjit@M8hL7c9EK!_w>H9B9#hvAdu$UTTJWPe3Wq6zaaFR|A+a`h2ff54Z!$ zM$OO8n$oS1(sC;RO1%7~T6F(1=BM+$*3qjm;HPE_l6uIl3;6dzoi|6zE|HibP?Y^T zv_BbO7zR$qNekU{DS2SNt!9xXyi}Z9Edbm9kKk2atX$>n&@He?=E#EmjN@uiyicK& zuRkufpSj$d*3Vh)_fe9}s9<^P!PKXKRCj8BNM%Yk{DgF1%crVu(uNK6-BNe1$AAX$ zRGZPnBpUa1N#9W24=2W1aO~*W474J=UD(1c>^0;2q&a>jRNU4jl zq^qeE+G?}C3)O|sn5TLmXE#Sbh+;698KAYXNVn9>H)7uD-DIxunkVUgTBn$bCs=)~ z&SFhw*D(;U9vRXlM@l|O+=o%%Gbog}#`qw=Hx0FfA^SVB9{La+*~4;X@ES?DpcM?*;c`0shRbwfOLu_qmKC^n-=i`k!a-Uv%g0eV1QI zeOcqo@rK;tLOGRZHqgM&h73|YQGveVD_};eNRoEwbp5{5NO|u<{BUW5%Fz(g{BA-2 zt`Lc>P8eqoIWtmqZbCN4zkoyVV8L>arFoIUghGE^gl8zNeWHjXUHjEi4aO{n-hGeH z-BCRHr-MXg;O@09B|;CT%x~%EA)7PdnfQfPItMlS;;ez{=@~J48bq7@o~(Tlg<>{4 zug7@4V;ZDV4EU$MFK)Z@uLV$!XIam*M#NEKqE9Orns9srnom}@r{^myXO1SX6&KoP z>%@)din$Cm_&{t23;56`=^ekuH?`yaNh+Mq%ov1gt{r%raFXVwCXTwW7eyEoCNE6j zNxwox27P^jZxY)Zg?)#bkUaRSxysti;r)U$53JTum9J3E#~B<*KgGpX_5>G_eX=E8 z%9GT$LgfQHiFcpnfru-fQAjmVwDbtYn3kvxrb6oC8HLlMAnlT+P2Lio)mxU$7_ASoJKYu z|MO)3RRg(W4hyP?iv_7<;R31W)ClmqwK;|v z^3k2g+l`i55Yg@n4#KviGkl=;^MdtRDke4^CE~dETfb{c-^)l*hR3NRekZAC!%2yP z<+*7|uWKQR=3r%4wL2fx4qY`cO{{%+4GnCKZ2RXKn@Yas4V2_^TuwT^=f_-6#E-eO z-p!%)D^A77eYfRv@G<*a++j!7#N!pNwE?r3aB)84jwrr#{c@aU$ErxdapdB;wVwQFI_uP+&*WeMEu)=bIvnDt({uWrsO2) zgv_a)#ifF$OQk#Ld`r{ug}I%XUZQ)eTI0uOI+db%uzDumGfUco>!UYYww?m65a|?u zgH$Rq@L(s49o*pfp|+J-m(VOHS|&VCF7x~xX?-T?OWf=JzgP#sOBl_x^bD$T?%duQvK9FhR81(amSUNm;bGqj z%5=YWt$0u4L3)e@{?(uQ6+uHHhK_WnE<>7{H9Ko-$)aX<^bsdlK?TiGI#DU6Z1SEA zW2p_gd*bfnxiEp=Ck(EkuMd2U9cGIKM)>FEJLI97qU1bE%X}+sRXQtcn2(QIN|>IJ z@?u_%Dpgp-I|j-)u^5)m5Ak|x(w#mWISD@SH{PrW!%3;)upx#x2Ze>vBnBccu^+e% zyW@Wb2m4K;!VR~y?a!ZauN(ewtg3zzmGHs>qf*02-|f(x;`cK-bzgH!l)#tA2pVon zjIx-IgH%CC(c)VRMJSD$B~|L#_+5eJ*F$nu*c&H$|4)*2=Lu|5lEeV(ZflidG(U+r z$?#iQ=#sP4|A4H7zkny-B#_=w}aC?<(_mR*UN$xRUTMibVYqrUU4iK?jX4O*yI zH7+|RXSp1?TEH#EH+|p3B;Dk*{k$hY*UuNEY>?Y0&<IG1}<^zPq#3qNF1bRc8{gwXn0C((P*mrU;cb4bonp*FAnMQ#t zci8x>nC!E8@DeMv66iD%^*l-H1XH=$nL^w)zP$Wc}V&P+d-Qm zQJ25tSP0ZwtNj4VTr~;P>1Z*pz%m~W2EV!c+8B#ckI!^;ijIyL0JtxPg5r} z%>#Yg3c)jbrcm9e)!TXG_M9=j7Flv-IU^jDPDEIKV^*ANh{q*vOQ6v7;E zo`3veP}V?vbg;hV8H;pw*>Mt0ePSwE>Gudr%w>Wnt$5ZX;k+}R66ctluhz=$`n~Ig zsJz#$HEsAF4zhRDkE!!*CWIjp*WM2YE5hdN0 zu5_nLjp&r8K+T%*#w#5dPfkw<8Ut{>e#z&DVg)2BG!o(fv^j!5Va<%)(bZB1t-Z_;X@=z-k|1Q1xZYJEXBSG(Nl-|4FY#;^H_}MG0BH2b71S+X@Xza{}P*K1MCxKp1BvZY6!da!y6d z`?RyT@oUh~-zJ`19#Y3M{v&Ajp=RKBm zKc=Q$G{Mc95l53uHSTL_HHZsH<(1SZWXYguW9rav~ap!z0w>btTW| z^@SaZwwd{!1B*1tVFrWvw}z!zB1KM}zO5ER=e;x860x;?{jnp~&0QKE^|c!{RiFC4 z)*hp4%)IWnuh3g;`*kZ@^_ljZSllFInfmbNj8-HMN9aP^OfkdyEWQh*2Yrl}y7n%r z&pUDx8XUcCh>#-D2{?s5X;eIcqc$Zs4|TOxWCxBytWVCi;H8)DR#9pSimuUG=8nRb;4!IZST-tegec~^} zuoSD}?B&^CzBSb1KO?3#YAOhX2TK_>dt;umYg~-f>LFDG+&82KHuJy@H1M*N7QWTv ze&_VwEsZnoI8CK@DvX}xzqP61&w66GfA`KSTy=s9c7Gg(~N z>x!Vu0ZcHVLeJAvAZ4h#p(WN3AEZcT>YyX|C%hi zaK^%Rphot9`C=C{fF!ZQvASkuJC?blwqcI%pA>%ExhMdvvkFhhjEA!|eaClmqdl=Z zmly9ntlG*yi9yWFd6pP936n~Tpt>U2Ohw$lWZ_Cvl#rJG*58As2tPa&=J*w&dvqj8 
literal 0
HcmV?d00001

diff --git a/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg b/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg
new file mode 100644
index 0000000000..aaa386ac8e
--- /dev/null
+++ b/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg
@@ -0,0 +1 @@
+[single-line SVG markup elided: bar chart "CUTLASS 3.8 + CUDA 12.8 Blackwell SM100 GEMM Performance", 16384x17920x16384 shape matmul on uniform random ints range [-4,4]; y-axis "% throughput of theoretical peak @ power", ticks 0-100; data types NVFP4, MXFP4, MXFP8, S8, F8, F16, BF16, TF32]

diff --git a/media/images/narrow_precison_multiple_block_sf_layout.png b/media/images/narrow_precison_multiple_block_sf_layout.png
new file mode 100644
index 0000000000000000000000000000000000000000..a1999f5bc66715db33f85d6d5cb3d3365b329ac8
GIT binary patch
literal 37653
zHBY|I2?j0BlG|Gs8+Qf3Pzf`cOz7;1ZQzBCORAHKl+5)o4PKb;LzyY{h*Zp)Ud4RZ zD3NeBSe}$L8ZxRH;}42gEFWDeC>Zar9KT)6Z1LT89_AAQGli(S;`z8?4P3SN1J(hE#cj?|Ryz}<-*uut;Qo#a84 zj-YqO^kQFbTpJ^(?=XjapKfrjGaeMUw7xKY-%Ij9M1EO;E%TnmAN=t%utu&M*|_vQ zVH>%*VMF?}5gdE!U$e?v!m?L^_08J)7f?29*U!!X8@T_%u62F1{n*g~!}!}s{%_f} z4q4LuF2W38&)GZ9`rJ2Xjp>^%z>LgUysL|KtgXK4=8G&Alj8TlICPoDVlA@&r)Q0& z%wn+C|H*P%qB3%^7LPb%@C3Ez{jpOlLw$Jl6M{h9|scFf6UwVJq7UN?(Zpp z9}j&`0c2PCYGXZn;8z>#*|~o%;2fgtkFP1!5KsOt;xw51XS)%t4S{ePj1+gq$I;_e z-Di^*HTVd=)Wrd26OzYM1Uzz{r@Emvw$QjYmK3PKExB}(zP#KurdYYSTBjno6C&dD zM9O1kgtCAdBN0;3j?Pnk(J6P6+{ z$uuyYh)qJgl3F5#TJU(++n zY8-jb6zYidhE>#Y#`!HNoLK5oLiZ}ZB}HQVr(99n!Q5pR!Ac5!RS~NL9xmRqk{&ZU zOMbiiNNdz#opy2?i?8{fj8h^JtLIk--qj84W36e2(pyiCFDg&Tt( zKG*%h!Da2w;(w1(F5g8NON~y&gpyFUeig4EtIN2K05vyDho#PhRqe=TqLgE2v#H;k zI0vTTRfLMzHH6JgN!li%mhJja7WFzu=WiZH_8T}=l;R_^VkI1dpw|)o{91g>K-x0jWS^!Cav3Xz4oJz8rAkhKqK5s+mM%4aL76opt3Ss5CkV;13 zH2LraM7)3&$4-%{(5vo9)Qe(a&O-QvuPgDK?ppLorHcp~FA2cu=IwkieWB>=5}j*7 zM5e=cJD;Tb?tq!o`P1a2zYx%pg?mMDp7YPE`!3CbqC;Wzcyw-1Hew|RsD4fq2<=YvYy`XPrPOLiwJnhVx~HI9`QQ(vnV=Y zt<~7QZnY8X|~GaDlA`3M;DAf{^8Y1+7`*ux=00m=2h&)sfA|}PPtVe zo#+o#e}UKtzttldvBDe=N2SqEzVJx;%)rg&W)>~>ziVNvJmT#{3xV-u9nD*ZBqN~# zi+Wp4UCEbyIyXazLuyL=BBU`B#&>)03>%)7F7nK~%@)RO_ePzD?mIzd-ff}T(R{6y z^4^lta(xf%6Lm_Kl(F0zk%2GnBI6DgW(v-q&nW`s);&~RL183b^^%EJYsF%RA z0LL5jZ=z&k_%xN(@7)TP=bOp2;8qmUYBPCtt(~lOtH=PJQwYk9N({_2+JE;k=7Wwd z0-sGDWpH_jI9Wj|tpobR;$l2wrIwn8=ovaH2ksio0p1rT+O_JaXYx)YT=F_C_-=6* z;JLDVKzdV97g@{H*%4PJ?Yc#BY2qW#DLtHsWHyGX?J!4NQM!3$Y4)L^3V*5-{u2pIFgnG9q%7 z1?>&y_IHu9XkI>?Jj~+HxOg3FTWQ{NZRWFbqBZ zg&I-^{ECzEptKg9MabPk6i$WCFY}Q0V44Sr!E-%~4zjOBK_!q-IS4AYl3_!iU7jOh z?W#BDJOCk;o*r-r{A}{-R3fHqeE*ARyVPZ+P^hx8LR4S2$Q3C@T@ugq{_JB-d*$KC z>LH%YqZL!?#+759$8gu~T%)aI&vbo?Yi<=b&>rt|VYDr6#AFywU);6s{-<0y!GR{dn!<|^4-=~$D3l~EN(_;l)uiI z*Z;CSfb3a*%*PU0rR&EvF33oyLIwh%rY)22oDrr&k-1Mmlnfo8SAC$C#bxFMr4=!@ z))=M6HfhW@IoY~qCl%{T(Zx!nt7{ov!LR$tRFaQ(uu6NP?~A?FBUm?b`>^-)+Ukf~ z9LPQR*_2N0S+G&;_&9%UC)r?rhp=-A%MZaNS}~*GSB|`)J2+|^8ohK$i6hLNab^*T zbLpZ8dBt6s-_tdS^vzVLi*9bESmjM)KZ*)_o+J!)$5=lW@2j_m8CQWghKfROwQ+ch z0SiRT(Sk=(p)eto68|~GDZ@|W<)1_^zOnscJL_%(qZghoDv+=;a$0KNYB(K=4@r4t z6S@C%Hbya|CH)d%M|aqWTS~uq?~In-AsgH_L&?Ye!3mhK@QK#K?{CY)%US^R`T>mp`Y+`VXK@wsxCp0rahc-jtLV!l<) z#sO>UoY5}FKvtCbj6sOgf11DHlm*)7eOO{9fud2d$nP~f>!Ob_{(5i4ONbScKO zc>gUsc%1Yp{lNaDq&!D2sKWbUwZX%Qg_M>Og{^PrXJiScR_8`P@4sbeS2LGauB#FKI=hB8aKtu7TO?=Q zCN*TsXT#|XIL*FGgS_B-uOYv$&Ri?$G6H%`9C04CI@_f7QTz%5v3E+GlD=Mo8r$cE z?>-O3NSu$zc{^@>Dy5fO-p;(38_l75cWBPn<8njfC)JA1yi{r9dhwkhy{o#NsBFgF}!87_#!S}P9(*OkT#avJG&InuD*f= z8yVWJT_}u7#vGWb2niO(J?wYyxKO~suYWtJK4COE3u7lgIz_Sl5G`@oof)-&!5a4q z`+w|mkIJqbp(2^i2pt}MR$G>1J7A#2KON<#C+YDaUjvFfBw0HUILqO(eDsm#`dD;|4}F%RQ<0wSmEmLrZ=qaam0>ai zv;8c)$b$&kEaL07LY8IL4%gO0<3ms;A;#V3)R z2b6ZHzQv1M*8|lZMXSw8b&;R&m+b1wrp?GL?9ojJ*=DQ8B1P?LBeMwaz$(oLcu?A3 zy7Gzk)w@UL10ha8<<2hF1%$3Kfq}ZTEoJf9s0P@`L!-%rv?7*kD*h_W`a_pj=U9$8 zCQ_0T5q_Mg7p9N0f%4xqLZ=<9|>VoO3LG)zbL|$0JRjb zf1Y)FOG#If2A>`#GD|;d-2+no+l~-Tf@7EfyZFWfnuvN3ChmqA3dumPB)pUYqnzdRt-+Y?8FV@5V`l_7placrqm*Jf32~eT-KV-YKpoW=c(MO(`>XW zQ^U%^)@7}uwHO18dt4iN{s0M1`@F@mNowi7lgG?T%XGjeTby${`O~w}-kqO2!fstV zj?A2RX{!2}9IXf~)w)r2+pM(4xlzia&)*w8YgW^^B6b~z!2pC2wNS$cLv z)LrhQix||70+j;zrCB3+f zJo2XVDbJThlYUlDzTM{_0G4sGr&R4CgC1>Dy z9B6mjXqPh({!&Y`b7K$Y^;*(zC1Nf}$vDeWgQwPHt=Ci90V#9c^ zc*=O^WvW|s1ZuhRhAaRx%m9hXFDtot<``>S{KxKMAhpeVgu#zPo z2Y@VSn+#U9+X*7Hx$?$qz$?*vkTADe<4bI#V9Bg&DwnuDFz*`Q5S*u?Mj*uE+pg`k zH}F>xdyU>4j1Noc_a==YBcM>3J3%KxR)E9r8Er3u$6PA{|PV49n2M3hy&VtK$>?(x2M{zYHBW|DVmR1(tWc% zu}*RcIN7)FH-U)EP^}zfDi`RlBkE5YU&@?rq@pfr%X~s%Wmgj+{5d=47k>N_U3F30 
z^}0)~H^~Vs=-fcfOu!8M(L=gsbSTrCH+yM80ZrVaJ3wDjqVxm5DnW=|`y<|C20pb4 zMO0qYUOuZAq@#D}l>2oT>rKw@Kzo2DUyMDy88z)X!W??-8Y*Xg-))uvz$MI>OTf@9u;FzYP?!=Ur|=SrqGV2I$)xjCSB+aLizepy4Amk=Dp53%s>t zx@XF=CaB%rewDNH`V8y7bS(^rH68g^9v1U(FQjtnRjcq+iu4~Eus`?l9p@9eZm2J3 zlK^s$yE4$HaU*h7v-+}|mn&9=F@JoiO$rS)skJ9$?-7&_-xBg` zUGugpr~mAiTV1SoTJBQ;t)7Fxdw}UY%&axOliZ_IR*YO5n0@bal-?OIm`N~Ue)yug zYMW$iEK^;X(-BskR0Es@M|6U7)O*X|_8OF{t;IQG2JdN4~)-67wQ-kh^oXLzEVs@9eIN>)f}3RfHu z>#O}|K3`}=lPO<3j@A9z`qSc}9X#8keVQ`u60MEvF1cvao6T()*Rs)=MzpfGkZ}@Y z4iwm7&oyuxTDow`ht(R>8Wv+?*zx5is$)E`nv*ieM8pS9KJrs)@L&}%E$Sk(G?U;; z*8V#9lWrw{O|Nv&_$Ng|?7X#Erq?ps^Lg5vFWbN`Zgm3m$#{(=qk*&;aAII(iBEe(DalKv@{pc)KkK-|F$1~2$ zt@Go0H5$sNPUmMMC_gH8JNziTFD15EF|n)v{Fi<|@BR8Fwr>7eM%s*%$Xe$=$^sdf28jm!z7u{YZHw+L@>I-whgdtXOzS$xj|`J~TUqamTLq9;2LV9p2vK`VOQ0^1ZmrjEQgYZdf!w~@`tP(FgZ8! z%UaqEa|e^C!iIdR{Z~VxU^Z`|nTu%2JHsc%UbzDMo7^Ppy&hhRK&ONXhsGQ%UHq8j zTLtsN8xPp@#a206MvDTodR4UI6)Cr4xe=V6i@8IBJ~Le!OM`_e1N%`*3h5`Qb2ZVQ zJi`TmH=oXgQSx+MplZi>#OJf8nR249<`<|8WntVgdZqYcfzKv8(`{3iJURy$;hT|P zOE`WRD8bs$9Q!mVw!ylTHWF{_HP@goKM;uR<;ophBjjqCk^IU+Wv~UnyU1>VeJFk^ zbLQ@3MVz~3G(p$wg+M*gCd0CB#8-a1yXdQ8xD%amqj{&2we$KQ+#)Hz3=S^j25Yp~WA7ZI)=_jMZlt>`EJ&Ad#v!>L z*)26_dvw<6h$76Z@s;vElUtuRU6*`O``FE!4(FK)*V1#OY`OIRV;p|%&H}%E??^w5?|+a%|DB@@ zXp!}4hb{O1%Te|}qYGPo8$17d4KkXGVeKw+EIWFL{fEBU=zR`)#|5#OWbzmEKpPme zE8ySUH#UKR0unF=+rCaLKMj_jyYDY5ysL&05$;Z1crg}O{Y?#@|56L6;jEcSwi@oe zU#6$bd4aF05dNy}5y*I}l5)Qs%6~DARa$R)4xs`9^yC)oB)p>nspmq6M)~0gc90{C8eifkyssK zzIbU#-dou*UO|3*HTT#slP28Gm_6sw+R@=Gcnu*eqy)2{wENwnARlbFzinmV@ zSX(Ns>dr>!u-%>H0^IH4DE7f~V3ob&-UqJ~p>RHnu!&=|aG`kPNkQA&eHOI2q@CjL zoX)aUKFhrNKfy{_+pg>k3AT>|X)<@P4*9t+=7s{NBEk}(YToU!AzMCLaDdIvDl{ko z0*bXztjxmje9~CoyPELNJky_(iTo z27lb@T;wi4zC=t#o#xaL@o~&FkcEcUbe^C-wW2TfhfE1qjYi&?GFO(jgiIjf`zywY z>OlDCt^Y)pBw|yzhc&VShrEMmsiKf`m%>r(T!6N z(O!D;3jlx0n}CtZk926gKY@hR~1GU02G$i;g*o8CnC=t+546!bXVo^QToYD0cwcMOLD0G)kGfU_ct=V-EQ0bDGm!8TXW&HfOf}P z@w8KRy9O%xtW74;#L_WRD4>p9Bes*tQbRZZGEn{j8E7WM9)WaP{OpsKRj|(`YY~k5 zrZ(c2CYmz`%U4HJ;e4_}q!YuEckVv}aRj%Y{aL98o~)q7jJfPhn`5_;sxhzN$x_m2b(u4&8aZj|#sR~!qm50N>%c-4kpaGeptTAj(3yW~|^c&6$zIqt%i^Bg(x zAI8n$tP;D^Zk?{^Dy~mT%*j;AB{F7Z5#6(WRcQlPu>+u6BrHfOl7uI+(VF+osE?U8cJ5(vTDS-}Cc zWgdt_z3f*>-pYmOWR~L6WYZNp{45uQy;XimJ2(CuTZX3)E?pYQHU48#xK2vOayslM zGN2ShQ?29DF~5pLnFEJdaq|0=V0YZ}4zKl*Wi@Osx*H%6+79VnhBja$fwL9kH^hel zDH=E(dR1%n(J9lg=Q365kFmk$T~t#(E(1D$Sa(zSi+<^uuk&g!{9;-j6wEJDA)i{`Rj&)0;_5v{LTEZ&Vy7X z-FUyjz&k#TJvNs*$$})Z@ht#BMR~79OanZDDVWoo6N$U~476C)QouRkVhbVq;+pV6 zA%Mhx$RZ?ua`BAo0%-iAnv2!EC|4O~MwS z3xH+JSc%x5B@)t^e*ug9@;nwrAc%|WWFEJhKobG7p!h_=sUCf_R@{(JbgX9S#SU+) zw2l_va;=iOd|BK!*cBS7Y`doCuosZb}26ue+ycwz~0+YyY4Dl(c!Xc)JR;ox{_^ z!}Gf=rlu9l6xJK&5UPcRpMGa{DutrC7*Nun+rpy=M+(Kq@UN);q& z(BwihF<~0y=-~wg$K~|7UD264#2=?ls6|Ea`?{RdQK<67dK2P`K#K~%xuV@ElL#r~ zUPPjV&|a4wC&Jb$-lB`z4v3>-()ky)^O$Xu4Gbh(*guSv7-P%v&{M2Pw7zp;%#+DC zu~%kaSBHpM>eM1vd*1#445(LctF%ly^X(&{n_|bXh9xm~4xM_t+lw)3;{H20r@pyx zSR3k?tvA-PN{Wh}6A{V$;oRyl86#QljD_hT;DP}DQL;bpS`?vy<4~k8AvIUtMB(E$ zL}E4XJ7l3-*4P{D)oTmY=jz1dcRx&@TZ=liMf`sH4^s+N(^&7O_zm@xM9K?2iy`dM z#i+lQ87LV9_?|P8r6tc>i{RLIP-RW`+j{QqSm{>T!!Bb1={Y6$GVEMC%(ZQ%;V%26 zHE2)$-^5(;zUZ<*#4n8~o=bv9y%i!by0ffz;mG659BGb;E`>~ke85~YqtvDOjyURk}6M=kaU?T%ICDE@Nd1Jiu+cNk50Bz-|iIuimMTgl|%WOX2wXFPfLEq{qOHvcUi ziblFLdoH{f3{E`&AOtqiqb@Yrxe3Zx@Kx&6R8Os}#gB4C-Flk=p`R=GW9`@){pD+k z=ONEp;~2d&`^dJ~cM0u^sQAO7PXL6#^a03^D_*Mq|P z0ya|NP9@uhByW9pO-Id@4OK+PNN4K>_IDui9QDrRuUpSi=Dh4)qYq_nhmhdUkJfn! 
zMQa6ZhA3>0*Zw5pPI{r}A-HdOc`7yX1CO2NezA`IzAn~k@Y1tgtR&2#ukw3G4A{fj ze7M!PLupBtcrVp1DW#Va8HSHh19$??l5trPdyl7tS%s^UvYm;TEy1E0m$2vB(+8%L zEFE0UvVH@@`VtEn&PsU-j2?BTesZ{g%i>R9KxQ-m23X4zwT*zI_3x6t_RyIF7pj&| z7gMeWg1BT#(jbZ)D(;*i_HOTNP>cS?@Ycd9sAKiRnte zTZxm)tmrHAD04Fk=a}ZyI#yxS3aXTqaH8Z|39SMgrXHFgZQYygs$Pxvj+=|;V~Y_(nYe-zT0jqI5ajOOCR!kIS> zc>_U#FS3EFAC|mX`u&D(6Gsbfg1hxXsb95|&S+;?yvx^;WYX}^vuC;_;F zFG2F6q{pl))kEsMuc@*Uw#*+O?GBH2k_+R8L*+%CwN2e5{{fb0`rD&j(Ik@G9|`^- zi_(<`KCgG^xKr2jE@%fPewwiSj!igE=B!FQ-=Wo2IvVq^gDpUoWI?KdO;bgIvXZ9n z?PzSFUsk{LSOYG*FsVZ*YI^t)&|eb&<`$Bpzi|ssQa`FAPS(n>xrKdubINc`<^=$Q z6;Jy@R{(RSZWcclsd>%o<3>K=8*oslt#^`GnR*Y#1`aNrxf_k7G5iyaSls8q!9g$m zx_h^y^I#7CyO85w0QtW5P5Xqz^WbN#dIRYW9SH%^VA+y9NKK$9UOUI=MNicO7(GGT zp;Y6JyLy47%)#nxFVAz3nEXF)Eks^ay@*#X(#k9dZx0Z020HUc>o+%z%e>B1mMy9O zb|b`RXBcn&?DIjmtl|JJf-uMej{OxRJrpl3_EpUNMRWoyMe@g0`Z`x0Ix{q9ASu%x znzi$;Q@K;0NZ0k{r~L3%NQ@w-Vl9!}6_$TByxvPdV~Y{EQ#nh@pOL*!;8v2j7B6hN z)(IT9xTl}yt^4ZjvZoj%=D=>!fvV`JoR}zQ0vlscv;emr;51xODz<6%`u&d=+l`Hf zYMQ|?+>4=%T}ThpGElsCzoM*%+jlq494Z&nX&7rfVP~TQNON&Wkj0BerI0erg@r4P zLY#(K80wjFeF21F6Tn37B_8vonraf6GXD(J01eLvjkbyI1Cb zEi0*uANB_2x;_?mLx?f2d-L9`3feNVTJIy56(nUG+8ZBz-rP|ko9f~$DdVg=b`U~r ze<=iVdg{LGSq=DEnrUVhmA?TD5x5udUwro;LBtSdWo%^mcc^f^8LwsI3F(n zmdIba{?uDLDhL2H`i=UC*0o;mJcOk3Hf$k=lihUTsf#_fB%&&BtGS^{e7`*{x zfVcx0+ySD4UO3p#>YZnPgv7{)Ky5nGH#3e1D2yU+-|+(nm#c?+YcJ%*ay?md^mIlb z)&geIjK#rI?wR)|FXGByyG`qG^R=q4UmmDFSm6H+nwZY^$zko$O}?hdkt?piQi4>L zcQ|pFx+|KM9szUuXH=nvj}|mWv|kmOcbSZ+mkUWwMJ#3+ns#-MXoG{%7TP2+T?AmS z;s=D$*LJkP3BW&gInS*oaj9PAu7l(mEF86RC>84r`J3i|+C0z7my|iMk0h5T>sP<> z2N-520{}LgAu;rLjr@r>$eo>J+_TXsb@wEV{ky7yGE66=E(8ncBQiO?ZA z@kh(mXTWhbST^1PZyA{2-%L1aFZ`VfiSF=y;iZS)*Hc*&-V`Em1-QsTLU;6q+XwU6 zdmoM!B@&mb6&J4T5Gch@831-pHP&`$hQ=%xf&It$CwhPI9tOr&sg@D#$=v#aP7{|Q zs{|h4zpte|$k( zx@RvGAAkMv&U5{Yo|2n$jd<`Rm|lOB|aB=|t{0Wl(VU<~(e>c1#sqd`N6^9E9(B z(o1~XQTEz7+eVIOtmf$j9FM<1>spJyKfr5PxRKVFW^34B3ElWhBMRLyRyi1pt=z~n zo@M`7)GZX<;bDGHd``*R44*K*Y2M1>qmAKkZM$PS6j+v}&&ldgoK6h1S{4V+S!4F2 zCDz#urrh7`Moga{on6J|oWI$Pdj2PTi(lD|KkDUG<$tgnpBDFKi0*N-=p`OIIhhY! z#kj?*q!%XV+yz@!x)OG5d8~z3cp)ww=3r<2rM8O=r0_9_F=Dt1osF1C-GwZU zyJp|B*tw4*%AM}*X**Q7BBXya2zw(rOTE}o99@ZQmqke{xViZkW#PQtw!fy;aiFd6N=!`&TZ6_C|R>z=c!<#`+$+0iN=kkb%T3iM?); zQZ27CLtX_qCnM|}Q zcaD5<8z40L0mfu+A;6d_H}{(zt;4et=@waYJZuuG%E?CUwLp3k$YL-7(FX_8iywGTY@8pmWW zIMAM`u6dJOQqd=aDmvUbYqC@jI&N=7M{eEKhrbBKMm^P;TcN+t$m-y+-$T)L z1Ayv5L}H+uhQ>q(e08wBXyr-Ct_<-mvg(-o+kO?oVv2ia>_BBj5wN1I>ow`p4}o$c zCJ`olu)^Jfu1i&BZMI8M?x{5XMy){Lf|Gww;ayXU0!Tr{G9^NAr&6)oN1pg%A)|H% z(m1oRoaH3!7-t56;D?otms?VXOYYL;{l$d<*1}V#tZct_VqS6d-FAH??|A1W;2=(D zO2__vIA+x+yuMda_U-nY$yelQ*?;3DomfPQ{F-3Z2x&~-|8++qCbP?I!M3|Ql3zq% zhDobane=f6o2Z!>}HVnZk!ppL%~|4WUt?#n-w!IzSymXr@c#lEsJ0FB^#0xUMWm1NTX z9VyraTz~u@z%2f=SGoU9PwX4Euzm-99q{;HXBgId_BzJ(Paob_`r{k$^nlG@e5=N~ z^#9*|6MwxszvX=R{hs|#k#fuf*y#kmU8A5aAeDb}-&pT_Z6;#>Rp#)8oUr&5s=^gsJ6ALC zl`UCXaT=Nv?DJb*I25uKT(u)+f+))X_xGUKY)S5K@$vStebsQRpTTfG5HkX?F)qW! zjd2-z+kcJAz&QZ$B=lrufG3Hiup>0|l@?cNrMOGbLV^3dAt~;NkP4}aI_CK5C%S-! 
z5ev68$A(g3)Bw@}8CzY)$B4$Pq091FO9KPIegZ%^P>Mi+gxcY?ddTY3bgPP>{9^zy zc_h9z5Mi&2W2`W-4HKN(r+{0&K@Cl@xee85UxF2Xm}f|=&tS!fFmNphQ0cs#EWDIU zS+xF0B~SXS$>OkiZ~GKGqJ*s60bNADqNm%dI&-JhVGR#fkuz&C%ii0u`i&=J^(DuE zqBP2nVoA{*0lUqt^|VjK?ebfuS&-s&AqvZMN4Sms;k4@7uD*qikX&EyDgclf8d|xB zBC+bSr5bzcO+-3UD7v1T|#Vm z*mYcBkGoX_2%Y`x`c5u1Q31?cusCi?lV+{1067`{IGWc`-Zcce%CEGa0^1%hFZs5C zj_ivsJ>$u8djO)j^|T{~fEW(_e(?(qXlgd*-7?ooxhALJZI_~3KsX5BH!Gv9b$>KS z#5d3v53;F~;2+e98=9H}q+eu}E2F*2b-b4Dk2F8OV1Ax)cv{_+pAi5G-1V6p)!i)r zLDl)=Q~hWN%E)?nA^IAOXHhP+*x7unUmZtI={hz}A4!wDV*(I9QGJt2l?yRut|3}V za-O((n$7YUCa6~vVdHd`wn=m3v1o;~loz_QzE1uJwbF%-P6KW!O{QHy!J_#>SHK+^ z82Qka@T3K(gi`s*aUeJVAs7`qN)eVSJ$7K0linVaH|dtDBEcCBSsv~=iVTFuHuTEw z#d{#6mYA_tFM>7u*LplGKfl8Op&*bN%A;;>YwQRSCq$;O960yOK0a5z#-;E+Eo?hy z=VQA-H4>F_VpRWR%7e|7%ydv)3EgsOKujcXj*??H!LrZ9x!0~_p=hvc>`I^5%9Lf^ z<(mkExk=NT&W6ex3U>Qn;2s{Hzo9y6!5z~Oos=vuVAn6`cxei_?qni8W2|Zj@vanb zC57V8*`St1A1~fnyP;xs0;7F+L+ z*@o7w!z(CBnYTf!lp_}q7dZO-RxO(7V?E%GuD5G5GUMgO{Tz17;r0&m>cxs9OWdng zs{~CeItKBk**hK@keWS-SvSgI=@25}S+ZIdy&X5X+hW*gFLG^Xxk;Cm1~oY$-H2$qd5M=;1t+Md8dpKiv|Qes(Yzw^hYTSXD(06Z{^8q&88#; z0MQ~^KKz8H0O1hk3m=G6S+Q;8V&`zhvv)$2X;X_`(dAU-fEUF^=s&$i092$G@e?Wn z#4@|f@;+EZjZVFVNNsm)3ghh+&DC{({f$ex?+S1!3no8tDVJ>M<2~mF4kL7*K;wL) zKBb&W&jC6KP0gK}X@C_fj%cC96f;`$6+c)&u=FpwwKic+1Ok5Ct>R=;wJc@Nvr2$eY5ijlic_bb&&J8dtm6}~^WjXr8 zpaje$n^Q>zIF)__;C6QAaE9%&=`)iNhrtR!9R@U{Or!ZHPxvJJf#@n0KG88>jzOnM z(ZMIbk|ujTe*N7gjm42$Mt_G^s3Pzfl zI8MNIqHS+`bQr4{h-{(T`qAyz$QC{7--AgYkg!ypn}~HNMiYQy?mf>kF&~*k=qO)m zBY#CsI2B)k$XD%0fVdPR76aFQ#HEOxo_w?+E=75O7HJp5hF<2H(IoTz7P&N_QWQch-Mfg$y3DVu_kOo>& zTJ_?%Fnmz=TlaqlL9*jy>I{Pik-)BDR8aO6GeH8(1YF|oKbNLI@ye}yLa~!%x4_MX5*K=isJ}BOaq)mTVdik~ z4Mq>!MxpW(fuq~3sCQ`Sm#bmk;&Pj7PswYD>hflP0dV`vdJfuN>}RW zPPK_hh^$k4A46WqAdxXGT<+xD!d>xKUS=u{@G^txa|yPZHS{FDs*2u~Kr3lMq{Ika z=nai&+fLlTexhf8%{3>$OsOFFGqLOK8hB+#?DB&?1Ct+7XGp572at?cM*vVv-69Fa zN>9kznD#^|=gka)=XV_Rrsvg!0ZS2;NL_iS(z)0OqD8DJm}BH;I*QmDj10j_ypy+= z<592%;A%5_?wk>XgmU_PhVUKG*u(}jG#mh+0i+kOv76sH5!Jqr24K#|9e2vVDHpVA zW*qDF+^M{<@Q_0)rPmjV&8lg+?H&g#fPY;!Q|e!al~Kh ztA{Ce;u#Uw$>j)I-U_DPr^xS=*SK!FwC4~X@LG^6t3tuX@T?8Sr0e4dwK%ZL(%t6# z4#@_UeBE-l*%`+_D@!aBZt`dCkdpJ=)M@{6Vu#c>GsdGo%@`gBwzXN}xe>>1HIz++ z*2nw%_J^{Uk%F_r-IhiA2?2h@s~&Xkr!G!j{e?L+Do;|N#&Pn#kKh?%1sftM_sjnq zL~{A{vh~gv?@si!6rg+Ufe1n`Kd#(}dONz$l)bhnmfi3<{ij?Qe{msJX{P%a zJ%q7jGAan;Pv}^yS-h!EZ+fV~@o$uOp0<|$v8I00b*klrsl49ab$>?tXx`C5k8J1& z@L<%gRNbx%w0CaYicpd&TfE9)W6G<$K`mAyb^1XXRp*cSCupluF>g*cqb{@*acMNJ zqcN^8XUChv>^6GQ!WQ#uZoW7>=!QRI7UDA4U)e4FvD*HmN6$|WY9fJq^-?0_sAZY1 zxcDWcWDi%Ha^%G0iJZk}+hn0x`-Y!cY|9SU&e$8LRO>j~cQM_kJlw1OHKb7$3!gIK zQ&dClIz5mT&HDzupS06D>9Xqu*PhshhkPuZ9H*v&huuvgI>F&!)!yOxYF&>bO%6|V zhRxoJSoXiHDfdI!VENnYWv{3`jHBy(La@v3v#1}dqPV_-1m;fCNEQD8X2}b%u~(C{ z-=Ev0rNyCOST8H?J77ik+p!z$^iat&Gnep@I9v(Ta;wK+C$y5;Ob4Bi%Jn76eztjJ zc4RT`yA#9A28&Yai-VMKv>Tml8l?@)Gl9e>gFW986?|XK z8l+ONR=~khq`d)_GVxI65ijg!`Xd;X+Z^eg|O5y8iUxPYjBNd_z;7Ou1=wAs#F zJ#c*S8w=nr6l1bA4m>%?X4^et>r5nnWK_;)K-Y%iIoN^x!Z|e(S1IvgZ-iaY5_mkH%iH&Cax!PaFBcjKTm*z;Px6{5sxvfB!GP> zM{+b$vhha6D$X4iBrdBNK3{)Wyg}Kq^U9(=y={R-HPuz+Xhg8$kxX+6w4J!KKW!<> z&KbM2uNj(M=w2|ze9Y!0lD9ACNHNAp8Wc9lVqIU$Dj}t!Q@}NXR1Qu`<+AlaxC2fi z#f*)~G~j`8Luj);fr);!ho zReMb;yPzioH47L1Chf3vjZx4ouZ0}tPt!i3Fgo+Iv|RSfz9Ahn5jT)4F$8jc{lZNrJSp9K}xSEu&^w9)V!KFjI=9NbMRF>N8XkU zQ^Y=P4K%NLk;{l9shG~w+&D$aE>(&GGlhkv>yD5kQe{5ogA?2Suy+ce+T7Wx-Dyne zcwdgbJ5$agF8Q?i_CMlv3%M}!yfpF~#v*3?Q+b-1X;hi zUZ&IVD36#76pKOcjca3 z?e66_2lnaVJzZUPI&2%L($z4NO?#Mxc(v!`Z-PoV=SNv z!yhs$Tzh_pk`*K0S0^Ov?LC=b>XZGk0cFb;c4lg5w1l)NMc(=wgYq^u#^x^!O3{`7Vj&p|Das#iTWBtd 
zq|9V{rR!}L%0LfpM4u$wKBI_YHjN#Qv5l7tjDh6)yrDj+%!EFOQ4=0(|^ zCn42#8qvIj{2IAwXR?#BBG8}sF3AluT2_(-tP7PxCSDDg- z!qCB+MNVEN>N+!|vRw%e6^i=$nOgT>JpxASv%VqZPQDWwE$LjMYJo}8>WXqr6b^6G z8_VWx{F35ces7nIP-K+-z1=pE)|UtQeDw~vfED`hoqS9=y8Gl#nbjzL+=cF%(A=fz z(f-pG><&kO4u|&JJ8JH6I;|f&k}hi<-}@2}dLP!=51i9Zl7>15Cb2_YOisc6T^ZuR ztFT?2Lu&w6De#prb2|!jbzKbni*FKg2w}INZGE3@SaTY;@R^KIeAw|njBw#L^Zs9@ zAzo@yqZ*wFYHn|C#RUsIs!@?HV-@JDAaO8kYxZMDAv{fBbiit`f|j5mlo-p=k4Q;` zP>JYb3wlb>tD0<5&sP}bW0c=iV#J>?iWLV&*)^zAjPJkOnAux3@x{WJP0W0yT4+19 zCF)-I;1W;=GbmFPC2O6pVeyKQv(G;85h|Omj4cI}aCi z^VbHU8n+4nlL>%Q+T_MB@FqZlEye$f$HdUzMO<-eoCKhZJ?l^gsxKKWD%z5015=2^ zss&c9V()~N1al-K{{TjL#2rL_-oaT5>J~4)ea#OGjk=?Q4RIwAZaqF*It4T%Y^}*l zTb zK@^LLS0J~_98G+ory;pJDP_Q^w#v86yn4S8WTa7fVQw0k^^JhRM zLY0Ksj3h3xjgDV+BhT2nk&3Un5$D0xS|NH&M1TJlbKf4Rs$?WSGOJL-Bk@RE(KOio zaDhujys?6!nBj+|awG04WIrv~_|ivx&rC1B5U$Fn3Rq97!;K(Ry5V=z$>#CulJS|1 zHhgBqdjuEfbX)G>W~C__gYwVg_d?A|el23VRV<|kJS^ziVJdO`*4h3U&+R4Xi+6Pm z>4^cC*4-+Pxr59?-2#}!G*e7vPU}q0&iYKh*TIDOM*Rf)Fux@^%1@>Iay^cif<+8K zwVpvE@7ZJS{r#0^eRDf7F>t)_|GuTyuMYeF7){2y_~4(jCjWa2^=FanSpK5o|G~XN zjPBvAS+i9&9BN@HRjTY0uPahTmUBiy7J(brivNYV0{%t{tuqO1^#-v2e+3l&S*-uh z(TcAPAR(O`-^#z#N<(DTR###}4&tzN2hc^==irQ?_1I+|}pTn<0c`0Up| z9vm_LrNm~7^##>UkEfS*R#D!SI#=Mz>~!6_S2B*{dZ|8(o5rUzb%nC+tLWOS#SPz4 zlmAK91`a6p-+Wa60&@IcrhxdTr}ew#=Wlk4ujiU`WnK;Z4$u5c$NYc9;Xr|*?;An_ zuKIuPMfift|=g-?KaQ3bsAtKbdBc8D}XjIa8Ps z_3%b)AIKx3Nn9oM)hqZetuWkBo1x5Ia)Y=PhODm>>kRG>2y2V4U&-A$ZlS6bw*&+* zr-0?wHT)zQSx;jG*Ip1s37v3HN~}>*u8nhd^)-S|BO^<9+0O>zJ+OW520Zb#47%mA z=wNohLOl%`89+$Du*9hk_YcolN4H|dah*lWx+Se0L?NF>j8z~HlFk~Mp6!bq4x`i# zsw;R0n82qUn0IbLnX?X*GdOP2eTon(Q9NLj$bkFG#nlp4nHFDAgNd!mn=ZOjXDJl@ zyBI^*8g+D%np>Yj^cO^kCBl`J`r#nD;s6j`PjmYRurXn|h2#*2ZiyV>+c!UvPkqX3 z^{kYVB*btl^sRi)B6tP{P$5w2(X8)Ec*4Szu2!tB5f{VUkF3jFeHnol^2KGyK&n3vP*yi#LmOHCUhE z2onhs;|==YWvH^_B^54B3^@^kVL@5;OLer`an=#2a=@yq#1~#CX`y7aQ<6;n64BLP z7oiPXQ@$;tS%X&sX)@?H3y^pL;DMQI9{OxvRfSlpzqw>Dvy`X}s|uj`ih5gwv8!dR z*X$LW8#EqDM>;~j%y4|4^gk~Q*fFQB7uJ<>@(5a+}$JDbnZ(FzLU7jJbQl; z?+!-wl|p91-FuIK#nlnM3QjmJOt76e+!+!k47QGWlY&-i%Mdz91&YpL*yZaTunvj1z03`igu8Vm4GMUVyg4Z%TsB-%Wgbtm1S~ zJ+<+u==)pvo=K~hKEOz*+%(gf~S&%2qrl}V0eCW;dpcH>VKhZC|O z;sFy0$lcA9pJ)8zX823YV@^W-smYvbzPu7CRgzxTyjT<}DuZkASB9O~*-d=)w&yoV zs_tvgb`cfbx+hZ^hk!vnp{kPqHaMmX2314PU_xKRxKoc>SiB z1bio~#NBqY_fQg0=F%HMlM4AGi;Y>=e0xRq{aMp=t2o}4In8_^>B91ba{JBC_6-io zj0P;H1o*;jaaWLB+@kyLVzxlJ`BTnY`R8q|4=#(ps&i4`HF_iY4eYn7-C3L{kkVxi zuRN`tq_PVa10&+|wh4I7t>qWBhHU=6iy8 zzTA?$F5vb6!c!J$_h42lgRBX3D?B~K!uwwms&|~3MU$D)>Qs}OQvxdZ8;Z2C^dPTU z%}Rf?cx83o>Z1C!FJ#stc$j5aF8=Ea^!^6ubT9+obfXAP0PCozjH+MZb} zd>;%M0oHc2&rl2)|MST!-6TDG=Wfg8u1Mu3p?vfU7SCrU8gbEXHpVB^-uc?rle+oF zxoF+UQEknY1Et1DlG8x}A!QjwnW0a9GRCy>-@MgYH+qwn*|q(l_G4``J;ar`OXKgx z`t@6s_aD~a6MTEUE4R^w-sOX&$ocK6;>2q&avWeJGWN55j!Rvs!tTM+Wi=O;^J$fx zDE;iTlA71MZD(WJgnCfs$aUfb99n8Uq^+ND3SBk@97Otpw3S z%z9Vm|Kss*YO3FxD-M@?57t%e9>B>kMA(i+Vq$uYBGaMbPI~|f61G$xUA=%os#wq# z4?y$ugjey)M_`J@Vo%5RD9tMwdG+7YqA$mh| z$(Rk+KE}ZfZ%2Hgg3Q<#HnaEaKoUlzc+$%Jl6WUTQR-pw8I{}tx>uyiz|N~+VMaW- z3ub2&DdC;O-=Hc|J4~z&?+M~Ho zl*|d3&Kq|#ON82}I_|&Y;t9X-5Xj+m9gUyNvW|m&fWqC?qxQ42D0w!Zf;Rc$D1)I| zP?i=W>jFcLg!((Vbk_QGR^36-2hH?Fe{1CikF%>H#}0LM7$+Uabo!s5j9-3R-eu>N z9t`#AIIi#tq2cs^3qCF47_B^}gJ`sw%^>Zy0f9?nMW;aEaPvg6lG$#An|AB+njebs z0;#L5MK|N-)Hxr;?^P)u>?UrKY&#POrxqjD1NN7CQ)cS7t%(`?EEk=MS8%@ON@)hW zclE1d?FHNz`JtKu#uC-ucdiwyYe<_f(EI%TymWKl72lnm>U^hQt9(wZy7_u#oa7Dj zhiXzY{L@Urv17Tu8~eu=R6S=MF5Rijt91WpY0C@c-;Gy1Ab;=-Weu=Ss0%%Pj&FOu z;D)zIm0gd>eV!90c;B{*A)^m5GdAfGYW)^=B3!oxlciM(zYmc;oTXJr$aSpH z&CR-+*VbdyvZCerk!(o~uT$fFF1jj9F+|`1B=s$#vv+)cN6Y9V8+FT5fO$pSU_|gB z+rYwu4ylAT6C@X^ 
zbuy$)H)m@Q?up<&{mjsl;d`nM&+xUmQl5)}r&<$N+JQ@jcTy>kkYMJJr>*aiJ>%Z< z;3WQ3H^>R{c5lD-%ba@3VlC5f(_C)uO5q z5+iK7tI`Bn*PYAEzjnK^I3#a3en6Ri3NRLgxq4M;H@DLM3M=)A4)Dc_y0z8*wUV7q z@L;5ufgGY&r#WFE?6G42_hXV?XEl!y%8|&tY1v$WMihA;%0*i=_=oYyn%T}eFwZRA zH)AmnA}oDR{L>f7ak%D_9X{OEiq*X%X?bKl93f|mX~oSfpgqY2qaEZ4XcXlw+tS0{ z3m@_5fFG=2$REhFONgUydPr8W^qwK2467xs0Oiy2W*mUP!1a)AL-_N(bVkPg}~9{Woh+tc^f@?N|-cPu5zM zJQYJmW5%6WA&An>=hPmFaVqc?BWK)0Te!A!Su~X<$)I4dwf4NWIGAEG;=0oU$ZMMG zW;&;A#@y73f30ujzd~DDvy&<-uh4z@G1N8=ETma*;p^5-CwZZhf@^CGi`+p|;10oO z#Xf5EWl8utIxoJR3ATdPKpURHeLH$lB|879yZ7^eF!Jf}0rA3DSD=+;?sz0*J24kk!-Q|5F9xixI4M_C+(+uRB-2`)xJ z1=~>>vX8b0X`I|(8vP2JxNN^nmG+1-Gh~d**49|nTD5pMNXsd%FhBxHlSoMZ=v#}o zj`Xy`POm74Caujghi~>>T4a|lqnE3i$*Li-vF#EDr9OADlTl4Udu|v@PkqWvpH#?v z?x4LxQg!oD+E^dY#h6kdyQ}OmEYZ|zcl~k;YBoi__3oGD7K*Ux#_3WQ)%iwg&lx^W zIFEh^=Ku(&qg(b&7ES|dfS`$nTt=+^y>1OMqzts_W4%8v>70sXvNTwkSokpG)l1@8FYx%i* zil5fLwJ52&IefrYwaU2rDvp4oPNum6IHp`2EY@SbGFO` zDdD~-8!y_s$r=lAu1F!7^vkeZw73G(XZ$+Qp4i_Jmp-R_rs$0I$T=-!0QRtl$C=%! zKUQ+uj#J5AHT}>Ri|Zj>t=814d~S*@&c04o;G%1#45SpLtf~ru<CBKp<&U#m!=v8p@--MbZCbQcFHV&+#E5; z+?-zS7zMd#S*9(Xe;Fy6mSr_~nF6R30c7Fc@F_!XaA~|Jl%;9So`RD3eq+fM)Og52 zJ(Zem?&a10^wZAh%NNe-_+HoA)z(3LyuC4+yT%CQb%~Z+)$qP}L9^828b(Wt-EzZ1 zDv{l|OF+0mib6c3D)E^Av#c%Ikl=v_5Uho(kndxj4`Ns%b$Mu zn?*kx$pDv?&t@3^^*P{%pEm;eZo%Y~2HsS-gamv@(T988E>N@Xjc~GGc18{*R!Z*l zL2rv+933iL9M%w}KbH2)l|ZOPAsy)$DXiFMqW4pBPv8>AoWd#(@L!nENiM9h?;tik zp4gtRhREjA1_PUKfNTLLJ`WwP`4c*3mL~Z~g?Qr?(?45I5zO#4TZf|4wtKL(Cjkkp zm6fs1=%A)g7}0iiB_XeC3no9d->FSFu8=kZ*pypIZ+3?8G`N*gBv1Uzg`5c}398F= zl$A9Yfk=CTBrGi9;%3-1E?^f(=@JG~y7U`dJM+pcTi#u*6x;!-SwiDjCOPOQTYYuF z2SFem9jC3eD*IZ!=LVda&?__P%~qZ?0Obit+)`j%kqR3ya?G+}Qe#lJos6g(!IB^t zV<3R!V@8=HXA-;j=gOEr{MyWoxntu|Pgt<;Y4MP+}K%|M4pXLLBkbPIG44d!GhdPSNBaZ|~wo#&N~mY*Tc zS6K9GyhgVw`S%M(97&qVnuppsTR^j<_|CZn0#2hzkDd#r8NEz$Cs6_`kby>Q*?-{2 z!)m;FrRB7c)e%Pu4+gqs_emF z7(8v}0&IYE)NsD9x6ZwUys*kS5M*_EUtem5q|%5p<=qm9#wx2w6OS^;MN~)gGfaaZ zDtjnrRaRDZ>Gs?ptA|QB3l&6Zc63kk4fLrPRvM(H%kp4}O<&IM=Ts}bn}%&|g_4a% zM6RrQQC>lg(k~w$kS(K16uBo)WgfO}HP$c(;-i7*kif|6VNRY;a^`Oruau6Lr>a@1 z2+m+Sh%H4Swjg1evs?j{0bOTCEX&hy0NogA@35^p~t%Fg69KViPi*j{g|GH za1X?j9O5CB$u9=CVZMSdHwktVa~|zC-|u({S@$}t)^WWpnb0bG@FP=5q(fo5b%-Xp zgE(-%HDPOltW{;{AXx6%1mnNTb5d}3!t+x5E@I9(oA(>-NhEWXa!QN@;%F!F#$z9f zSD}vAyBhQ)<;2UlJ?6s$#Ehgj^`}Lguo-AWmHG$To8ty)c@G9WKEVr_hDow2d$oOtIg%H_a z7blSB36P75hrb_eC49V8o^Km_lAhUP>)a-`N5^FKsmH*(u|j{+>K$+!n-9s4w_cs> z^Zi_bzMOug8MizlszMhu88NCp1A970uKIk(Vr9;i-I+@WIQ^bNDq1wab7=$@cGiAY zm~r@s)8J8akqSC7YhRpp>U4K|OfDnw7r$C3Y9eO@Kq0?(F{nNTc{krfF~kZ#SXIZ< zi(wdl?pTFb*SP0Z37P=hprCanWYVYmNP!}vIH#-d#Jm}BM>*>k@lje%X{BQPUQ^dX zM19_>lxOO(B3|<`7&{udfT-Y))|&=?{fqeFHb1DvB?dyA$)Rl z|2Dm&YCZC?KzidXU6vP~CroaT`8Cw>K$q-i*)KY+sYR+lR-mz!?mng415(TMWR6k# zvv1PbcRupf*Av80G?&Ipiw%g8C0Z_)SfN@GTgHd`U9RH(wZY#9A16sgx4|;e%?cs| zxJ-S+Lka`TUiL+P?H+maJpS@VWE2ljg|2p`vOEyh49RSgxw_ zO5!lg1bd0(XHy}|i7Nfrj%8br1V#%P>A^OC?(LPy`&Q@U!L;ZXu{ub; zq7S%3BYa_l-tU}$o5=0Vdy$Nxp(1>Q7j2zQ59QL|P7*kQCJ`=t$Sim5xdhc9uv64j zTkScSfIp*ap*>IrzDnXc`MUcK`s1n(>$;aZ03cTJ97(wD;VTbikKEmiQQZ;sZNhF& zzJ<2abzSD;K;hJG=WK5`N<;2xei3i1+u4%p&D)Bs+d6)0b#l_XlGurVN{#M@@AA+Z zRHfV-G0Asa?4@iDB4xwnmd6Wc8cH6OicN+z6j}lzVi!yS|wI7-_t!)?IsJ z7!LRI@ySBxy1f=`(>tv1;;5^xf3TY1oZ;ncFr?_aG~!upI@^)gIWC?ZAWVK!6GGU~ zJgTjWwBIZ&6IWZ~Pg_#ng3? 
z#-wwt?&}SXj}dRxy&gl~d`P`P`$7dnmui}hF$Z(jz-2!^`ILZ^uMI-V?k-GyAG!Se zWFhg7w97%_lJ4b{t`imSX)BT!sJw&n88Vr%Gj9fqUnTnfR(G*m(^s4nor*pfUXGrG zhU~4R93*Lo?W`1%-c(X@sJ$825^zQjnc@+VXB2(4)AX{NW*8#hgY=Sa2cD;B+d1pc zc~Lb~$#O52%ZJ9sA?dPGM%{BO7Y_*uyTk2%nF>VFOqiVoF#hUD*X|0-Sn`!DA8&`r zW%wU0ZX)z|-)rbd9a}?%d%SX7nz?B8q;f6M;x4k24N8+uF~07+-gUo%$NDNtR$mTn;}=Y912b8(p?C8do~iEy5i*dtaOaT>Yk zzSYAjS6UgG?S}|Ys#8*&`mt?;vg3QeNO`GQ3M7HGiZ_Ov%pBO4D}Gtg`WG_tZMW5t z*)JpaZE`XE{Z@ex{YQ~Xd?<m#_;f3b8pT(kIer02GDfWC}cjuE3d8Mbb^oP=ezAwZf~Ry#(nf&@)rNGAw#Tf z7~HCLQlNv@NSGrX&Gm3to)*B_Fq~F=G(g^&0RLjC5Y=$^IUABCWpo(>a_Jb~%x=9x ztZ&{pWSAsh+{z8j3kvK@?C}6+gL`=*zi31DF^`^oSgg?lP1oW8f=*98p{2<$ss!BT zExekzzj_!ED$s(XFD)(~|R-TKj)U>s|7kc7p`s;SxTw&k%>xR|f( zmoV z+dgnx)gflnfcKBDoD3GwwDbQl%jg^b{V6EGx-`%foCV2f{v-4aoLK%T6~eNGGLxDT zKt7sv>&JZ-NM5)6huvGS+PV(d{0FD2_r0@KZ}l+5r3^kK{r0u#^>_baMOB|b`=m{O#t(uDfL{2sFTCJSwr`iPzk`0U v|G(Z%PVHZNyZ&w`JIue73m~nsSR1xOu`#QiB|<-ef0tD>F65oN@yGuGE0}@u literal 0 HcmV?d00001 diff --git a/pyproject.toml b/pyproject.toml index ffb66b274a..f892fe9de5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "3.7.0.0" +version = "3.8.0.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index 81bb8cfb96..d60e28468e 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -134,7 +134,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '3.7.0' +this.__version__ = '3.8.0' from cutlass.backend import create_memory_pool from cutlass.emit.pytorch import pytorch diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 6ae493b962..1e944a08c1 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -65,11 +65,15 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, tile_scheduler = TileSchedulerType.Default + + , ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None + ): kinds_3x = { GemmKind.Universal3x, GemmKind.SparseUniversal3x, + GemmKind.BlockScaledUniversal3x, } self.is_3x = gemm_kind in kinds_3x self.prefix = "3x" if self.is_3x else "" @@ -82,6 +86,14 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, self.C = C self.D = D + + if self.gemm_kind == GemmKind.BlockScaledUniversal3x: + self.ScaleFactorA = ScaleFactorA + self.ScaleFactorB = ScaleFactorB + self.ScaleFactorD = ScaleFactorD["tensor"] + self.ScaleFactorVectorSize = ScaleFactorD["vector_size"] + + if self.D == None: self.D = self.C @@ -150,6 +162,7 @@ def core_name(self): OpcodeClass.TensorOp, OpcodeClass.WmmaTensorOp, OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, ] is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops @@ -207,6 +220,23 @@ def extended_name_3x(self): element_c = DataTypeNames[self.C.element], element_d = DataTypeNames[self.D.element], core_name = self.core_name()) + + if self.gemm_kind == GemmKind.BlockScaledUniversal3x: + d_type_names = DataTypeNames[self.D.element] + + if self.ScaleFactorD.element != DataType.void: + d_type_names = 
DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names + + extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.ScaleFactorA], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.ScaleFactorB], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + core_name = self.core_name()) + return extended_name def datatype_name_3x(self): @@ -247,6 +277,11 @@ def kernel_schedule_name_3x(self): # Generates a short string representing underlying epilogue schedule type def epilogue_schedule_name_3x(self): + + if self.gemm_kind == GemmKind.BlockScaledUniversal3x: + if self.ScaleFactorD.element != DataType.void: + return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout] + return EpilogueScheduleSuffixes[self.epilogue_schedule] # Generate a short string representing the operation class @@ -769,6 +804,32 @@ def instance_template(self): ${compile_guard_end} """ + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + 'epi_vs' : str(operation.ScaleFactorVectorSize), + 'element_d': str(DataTypeTag[operation.D.element]), + 'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]), + 'layout_sfd': LayoutTag[operation.ScaleFactorD.layout], + 'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor], + 'element_accumulator': str(DataTypeTag[operation.accumulator_type()]), + 'element_scalar': str(DataTypeTag[operation.accumulator_type()]), + 'element_c': str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + # def emit(self, operation): _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") @@ -778,6 +839,12 @@ def emit(self, operation): opcode_class_main = operation.tile_description.math_instruction.opcode_class opcode_class_epi = opcode_class_main + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized: + opcode_class_epi = OpcodeClass.TensorOp + + tile_shape = operation.tile_description.tile_shape instruction_shape = operation.tile_description.math_instruction.instruction_shape cluster_m = operation.tile_description.cluster_shape[0] @@ -790,6 +857,23 @@ def emit(self, operation): cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0] cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + + # Shape passed to epilogue builder + is_sm100_kernel = (operation.arch == 100) + if is_sm100_kernel: + cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1 + if cluster_m <= 0: + cta_m = cta_m // cta_m_per_mma_instruction + + if opcode_class_main in [OpcodeClass.TensorOp + , OpcodeClass.BlockScaledTensorOp + ]: + tile_shape_main_m = instruction_shape[0] + tile_shape_main_n = instruction_shape[1] + tile_shape_epi_m = cta_m + tile_shape_epi_n = cta_n + + # stage count set to zero indicates builder automatic stage selection if operation.tile_description.stages > 0: stage_count_string = 
f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" @@ -811,14 +895,37 @@ def emit(self, operation): 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor], } epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + + if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + else: epilogue_functor = self.epilogue_functor.emit_declaration() + + if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + # # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] + if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] + element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' + element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' + + values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, @@ -1184,6 +1291,7 @@ def __init__(self, operation_path, configuration_name): GemmKind.Universal: EmitGemmUniversalInstance, GemmKind.Universal3x: EmitGemmUniversal3xInstance, GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance, GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, GemmKind.Grouped: EmitGemmGroupedInstance @@ -1195,6 +1303,7 @@ def __init__(self, operation_path, configuration_name): GemmKind.Universal: 'GemmUniversalOperation', GemmKind.Universal3x: 'GemmUniversal3xOperation', GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation', + GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation', GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', GemmKind.Grouped: 'GemmGroupedOperation' @@ -1255,6 +1364,7 @@ def __enter__(self): ("gemm_operation.h", None), ("gemm_operation_3x.hpp", None), ("sparse_gemm_operation_3x.hpp", None), + ("block_scaled_gemm_operation_3x.hpp", None), ("cutlass/arch/wmma.h", None), ("cutlass/numeric_types.h", None) ]) diff --git a/python/cutlass_library/generator.py 
b/python/cutlass_library/generator.py index 3fa49eae32..c75f334206 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -209,6 +209,15 @@ def CreateGemmUniversal3xOperator( gemm_kind = GemmKind.Universal3x element_compute = data_type.get("epi_type", data_type["acc_type"]) + + if "sf_type" in data_type: + gemm_op_extra_args["ScaleFactorA"] = data_type["sf_type"] + gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"] + gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]), + "vector_size" : data_type["sfd_type"]["vector_size"]} + gemm_kind = GemmKind.BlockScaledUniversal3x + + operation = GemmOperation( gemm_kind, tile_description.minimum_compute_capability, tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, @@ -6509,217 +6518,2124 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): # + +# Blackwell SM 100 generators + ################################################################################################### -def GenerateSM90_Conv3x(manifest, cuda_version, - log_indent_level: int = 0): - """ - Generate CUTLASS 3 convolution kernel(s) for SM90. +def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: + if DataTypeSize[data_type] < 8 and is_f8f6f4: + return int(128) + return int(16 * 8 / DataTypeSize[data_type]) - This is meant to be called from GenerateSM90. - """ - log_debug_line('GenerateSM90_Conv3x', log_indent_level) - log_indent_level = log_indent_level + 1 +sm100_cluster_shape_1sm = [ + [4,4,1] +] - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): +sm100_cluster_shape_2sm = [ + # cluster_m % 2 == 0 for 2sm + [4,4,1] +] + +def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return - minimum_compute_capability = 90 - maximum_compute_capability = 90 + # layouts for ABC and their alignments. 
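# Illustrative sketch (standalone; not part of the patch): the get_tma_alignment_elt helper
# above converts TMA's 16-byte alignment requirement into an element count, with sub-byte
# types on the mixed f8f6f4 path padded to a 128-element alignment. DATA_TYPE_BITS is a
# simplified stand-in for cutlass_library's DataTypeSize table.
DATA_TYPE_BITS = {"tf32": 32, "f16": 16, "bf16": 16, "e4m3": 8, "e5m2": 8, "e2m1": 4}

def tma_alignment_in_elements(data_type: str, is_f8f6f4: bool = True) -> int:
    bits = DATA_TYPE_BITS[data_type]
    if bits < 8 and is_f8f6f4:
        return 128                    # sub-byte operands handled by the f8f6f4 instructions
    return (16 * 8) // bits           # 16 bytes expressed in elements of this type

assert tma_alignment_in_elements("tf32") == 4
assert tma_alignment_in_elements("f16") == 8
assert tma_alignment_in_elements("e4m3") == 16
assert tma_alignment_in_elements("e2m1") == 128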
+ layouts = [ + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + ] + + data_types = [ + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + min_cc = 100 + max_cc = 100 + math_instructions_1sm = [ + # tf32 -> f32 + MathInstruction( + [64, 128, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] - spatial_dims = (2, 3) + cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1] + ] - # This function only generates kernels that use TMA. - byte_alignment_required_by_tma = 16 - tma_byte_alignments = { - 'A': byte_alignment_required_by_tma, - 'B': byte_alignment_required_by_tma, - 'C': byte_alignment_required_by_tma, - } + tile_schedulers = [ + TileSchedulerType.Default + ] - # For tuples of one element, the element needs to end with comma. - all_byte_alignments = ( - tma_byte_alignments, - ) + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - # MMA shapes (MMA_M, MMA_N, MMA_K): - # - # Different hardware MMA instructions may have different MMA shapes. - # This function may generate kernels with different MMA shapes for - # different data types, either because the hardware only supports - # certain shapes for certain types, or for performance reasons - # (CUTLASS doesn't need to generate all valid kernels for the - # profiler library, just the best-performing ones). - # - # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) - # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, - # where 4, the "number of MMA instructions per tile," is determined - # through some combination of modeling and experiment. - # - # For performance on sm90, generally CUTLASS generates 64x128 - # instead of 128x64. 
- mma_64x64x16 = ( 64, 64, 16) - mma_64x64x8 = ( 64, 64, 8) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) - num_mma_per_tile = 4 + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [128, 128, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] - # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, - # but not included, because they tend not to perform as well. - cluster_shapes = ( - (2, 1, 1), - (1, 2, 1), - ) + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + ] - fp16 = DataType.f16 - bf16 = DataType.bf16 - fp32 = DataType.f32 - s8 = DataType.s8 - s32 = DataType.s32 + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - # When generating kernels, the usual way is to specify 4 types, - # (A, B, Acc, C/D). Tests instead have 5 types, - # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), - # where ElementCompute is also called 'epi_type', - # and corresponds to the type of epilogue activations. - # This script maps tests' 5 types to 4 types - # by making ElementCompute the same as ElementOut. + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto - fp16_fp32_fp16_fp32 = { - 'a_type': fp16, # ElementAct(ivation) - 'b_type': fp16, # ElementF(i)lt(er) - 'c_type': fp32, # ElementAcc - 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) - 'acc_type': fp16, # ElementAcc - 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) - } - fp16_fp32_fp32_fp32 = { - 'a_type': fp16, - 'b_type': fp16, - 'c_type': fp32, - 'd_type': fp32, - 'acc_type': fp32, - 'epi_type': fp32, - } - fp32_fp32_fp32_fp32 = { - 'a_type': fp32, - 'b_type': fp32, - 'c_type': fp32, - 'd_type': fp32, - 'acc_type': fp32, - 'epi_type': fp32, - } - s8_s32_s32_s32 = { - 'a_type': s8, - 'b_type': s8, - 'c_type': s32, - 'd_type': s32, - 'acc_type': s32, - 'epi_type': s32, - } + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - # Other NVIDIA libraries may have the habit of specifying data types like this. 
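# Illustrative sketch (standalone; not part of the patch): how the SM100 generators above
# expand an MMA instruction shape into the shape passed to TileDescription. The per-tile K
# extent carries four MMA instructions (TILE_K = 4 * MMA_K), the multiplier comes from the
# cluster shape (identity for a dynamic cluster), and the 2SM path halves the M multiplier
# since one 2SM instruction already covers two CTAs along M.
def tile_shape_for(instruction_shape, cluster_shape, two_sm=False, dynamic_cluster=False):
    if dynamic_cluster:
        mult = (1, 1, 1)
    elif two_sm:
        mult = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
    else:
        mult = tuple(cluster_shape)
    m, n, k = instruction_shape
    return [m * mult[0], n * mult[1], k * 4 * mult[2]]

# A 128x128x8 tf32 2SM instruction with a [2,1,1] cluster keeps a 128x128 tile with K = 32,
# while a 64x128x16 1SM instruction with a [1,2,1] cluster yields a 64x256x64 tile.
assert tile_shape_for([128, 128, 8], [2, 1, 1], two_sm=True) == [128, 128, 32]
assert tile_shape_for([64, 128, 16], [1, 2, 1]) == [64, 256, 64]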
- bf16bf16_bf16f32_f32 = { - 'a_type': bf16, - 'b_type': bf16, - 'c_type': fp32, - 'd_type': fp32, - 'acc_type': fp32, - 'epi_type': fp32, - } - f16f16_f16f16_f16 = { - 'a_type': fp16, - 'b_type': fp16, - 'c_type': fp16, - 'd_type': fp16, - 'acc_type': fp16, - 'epi_type': fp16, - } - f16f16_f16f32_f32 = { - 'a_type': fp16, - 'b_type': fp16, - 'c_type': fp16, - 'd_type': fp16, - 'acc_type': fp32, - 'epi_type': fp32, - } - f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 +def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return - i8i8_i8i32_f32 = { - 'a_type': s8, - 'b_type': s8, - 'c_type': s32, - 'd_type': s32, - 'acc_type': s32, - 'epi_type': s32, - } + # layouts for ABC and their alignments. C alignment will be set later based on output type + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + math_instructions_1sm = [ + # f16 -> f16 + #MathInstruction( + # [64, 64, 16], + # DataType.f16, DataType.f16, DataType.f16, + # OpcodeClass.TensorOp, + # MathOperation.multiply_add), + MathInstruction( + [64, 128, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # f16 -> f32 + MathInstruction( + [64, 128, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # bf16 -> f32 + MathInstruction( + [64, 128, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] - # Each element in the outermost iterable is one combination of - # - # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) - # - # for which to generate a kernel. spatial_dimension is the spatial - # dimension of the convolution: either 1, 2, or 3. byte_alignments - # is a triple of required minimum byte alignments for A, B, and C. - # - # Note that itertools functions produce a single-pass generator. 
- # The code doesn't need a multipass iterable, but if one did, one - # could call `tuple` or `list` on the generator. - # - # While this happens to use the same cluster sizes for each element, - # the code doesn't require that. Different convolution kinds, data - # types, or mma sizes might have different optimal cluster sizes. - combinations_of_parameters = chain( - # The following are all the kernels exercised in the unit tests. - # Please try to keep in sync with the unit tests. - product( - ( - ConvKind.Fprop, - ), - spatial_dims, - ( - fp16_fp32_fp16_fp32, - fp16_fp32_fp32_fp32, - s8_s32_s32_s32, - ), - all_byte_alignments, - ( - mma_64x64x16, - ), - cluster_shapes - ), - product( - ( - ConvKind.Fprop, - ), - spatial_dims, - ( - fp32_fp32_fp32_fp32, - ), - all_byte_alignments, - ( - mma_64x64x8, - ), - cluster_shapes - ), - product( - ( - ConvKind.Dgrad, - ConvKind.Wgrad - ), - spatial_dims, - ( - fp16_fp32_fp16_fp32, - fp16_fp32_fp32_fp32, - ), - all_byte_alignments, - ( - mma_64x64x16, - ), - cluster_shapes - ), - # Kernels not necessarily in the unit tests, but used elsewhere - # and thus useful to have generated for profiling. They may - # duplicate kernels above. All of them are 2-D. In general, + cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1],[4,4,1] + ] + + tile_schedulers = [ + TileSchedulerType.Default + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. 
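# Illustrative sketch (standalone; not part of the patch): the C/D alignment placeholder (0)
# in the layout lists above is filled in once the output type is known, using a 128-bit
# (16-byte) granule expressed in elements of the destination type. DATA_TYPE_BITS is a
# simplified stand-in for cutlass_library's DataTypeSize table.
DATA_TYPE_BITS = {"f32": 32, "f16": 16, "bf16": 16, "e4m3": 8}

def set_output_alignment(layouts, d_type: str):
    # Each entry is [[layout_a, align_a], [layout_b, align_b], [layout_c, align_c]].
    for layout in layouts:
        layout[2][1] = 128 // DATA_TYPE_BITS[d_type]
    return layouts

layouts = [[["RowMajor", 8], ["ColumnMajor", 8], ["ColumnMajor", 0]]]
assert set_output_alignment(layouts, "f16")[0][2][1] == 8    # 128 bits / 16-bit elements
assert set_output_alignment(layouts, "f32")[0][2][1] == 4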
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + # 128x64x16 + #MathInstruction( + # [128, 64, 16], + # DataType.f16, DataType.f16, DataType.f16, + # OpcodeClass.TensorOp, + # MathOperation.multiply_add), + # 128x128x16 + MathInstruction( + [128, 128, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + + # 128x256x16 + MathInstruction( + [128, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + + # 256x128x16 + MathInstruction( + [256, 128, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + + # 256x256x16 + MathInstruction( + [256, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. 
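# Illustrative sketch (standalone; not part of the patch): the guard used above when adding
# the extra mixed-precision output variants. Kernels that write D in the A/B input format
# are only emitted when the accumulator type differs from the input type; otherwise the set
# would duplicate the accumulator-typed kernels (e.g. f16 inputs with f16 accumulation).
def output_type_variants(element_ab: str, element_accumulator: str):
    variants = [element_accumulator]      # always emit D in the accumulator type
    if element_ab != element_accumulator:
        variants.append(element_ab)       # mixed variant: D in the input format
    return variants

assert output_type_variants("bf16", "f32") == ["f32", "bf16"]
assert output_type_variants("f16", "f16") == ["f16"]          # no duplicate kernel set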
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + math_instructions_1sm = [ + # inst 64x128 + MathInstruction( + [64, 128, 32], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 128x128 + MathInstruction( + [128, 128, 32], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 
32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 128x256 + MathInstruction( + [128, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + ] + + tile_schedulers = [ + TileSchedulerType.Default, + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + 
"acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + # inst 128x128 + MathInstruction( + [128, 128, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 128x256 + MathInstruction( + [128, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 256x128 + MathInstruction( + [256, 128, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 256x256 + MathInstruction( + [256, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ] + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { 
+ "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. 
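# Illustrative sketch (standalone; not part of the patch): the skip rule applied in the FP8
# data-type loops above. Combinations where both inputs are e4m3 but D is e5m2 are filtered
# out before CreateGemmUniversal3xOperator is called, trimming the emitted kernel set.
def skip_fp8_combination(a_type: str, b_type: str, d_type: str) -> bool:
    return a_type == "e4m3" and b_type == "e4m3" and d_type == "e5m2"

candidates = [
    {"a_type": "e4m3", "b_type": "e4m3", "d_type": "e5m2"},   # filtered out
    {"a_type": "e4m3", "b_type": "e5m2", "d_type": "e5m2"},   # kept
    {"a_type": "e4m3", "b_type": "e4m3", "d_type": "f16"},    # kept
]
kept = [c for c in candidates if not skip_fp8_combination(c["a_type"], c["b_type"], c["d_type"])]
assert len(kept) == 2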
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM100 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], + ] + + instruction_sizes_1sm = [ + [128, 128, 32], [128, 256, 32], # Mixed F4/F6/F8 block scaled only supports M=128 for 1SM cases + ] + + instruction_sizes_2sm = [ + [256, 128, 32], + [256, 256, 32], + ] + + ab_types = [ + DataType.f4, DataType.f6, + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + cluster_shapes_1sm = [ + [1,1,1], + # [1,2,1], + [2,1,1], + # [1,4,1], + [4,4,1] + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + 
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]) + ) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + [4,1,1], + # [4,2,1], + [4,4,1] + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : 
math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + + # Set alignment d based on Destination format. + for data_type in data_types: + for layout in layouts: + # alignment for a + layout[0][1] = get_tma_alignment_elt(data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(data_type["d_type"]) + for tile in tile_descriptions: + math_inst = tile.math_instruction + # Filter some kernels that does not meet the alignment requirements. + if layout[0][0] == LayoutType.ColumnMajor: + if math_inst.instruction_shape[0] // 2 % layout[0][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[0][1] != 0: + continue + + if layout[1][0] == LayoutType.RowMajor: + if math_inst.instruction_shape[1] // 2 % layout[1][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0: + continue + + if math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]) + ) + else: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]) + ) + + + +def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + ] + + instruction_sizes_1sm = [ + [128, 128, 64], + ] + + instruction_sizes_2sm = [ + [256, 128, 64], + [256, 192, 64], [256, 256, 64] + ] + + ab_types = [ + DataType.f4, + DataType.e2m1, + ] + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. 
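# Illustrative sketch (standalone; not part of the patch): the alignment filter used in the
# mixed F4/F6/F8 block-scaled 2SM loop above. A layout/tile combination is dropped when the
# checked extent of an operand is not a multiple of that operand's TMA alignment: for a
# column-major A (or row-major B) the extent is half of the corresponding instruction
# extent, otherwise it is the per-CTA K extent. Argument names are simplified stand-ins.
def passes_alignment_filter(inst_shape, tile_k_per_cta, layout_a, align_a, layout_b, align_b):
    if layout_a == "ColumnMajor":
        if (inst_shape[0] // 2) % align_a != 0:
            return False
    elif tile_k_per_cta % align_a != 0:
        return False
    if layout_b == "RowMajor":
        if (inst_shape[1] // 2) % align_b != 0:
            return False
    elif tile_k_per_cta % align_b != 0:
        return False
    return True

# e2m1 operands need 128-element alignment: a 256x128x32 2SM instruction fails the check on
# the row-major B side, while byte-wide e4m3 operands (16-element alignment) pass.
assert passes_alignment_filter([256, 128, 32], 128, "ColumnMajor", 128, "RowMajor", 128) is False
assert passes_alignment_filter([256, 128, 32], 128, "ColumnMajor", 16, "RowMajor", 16) is True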
+ if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) # UE8M0 scale factor + ) + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) # UE4M3 scale factor + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) # UE8M0 scale factor + ) + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) # UE4M3 scale factor + ) + + cluster_shapes_1sm = [ + [1,1,1], + # [1,2,1], + [2,1,1], + # [1,4,1], + [4,4,1] + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : 
DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm] + fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] + , tile_schedulers=tile_schedulers(data_type["sfd_type"]) + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] + , tile_schedulers=tile_schedulers(data_type["sfd_type"]) + ) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + [4,1,1], + # [4,2,1], + [4,4,1] + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + 
"sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto] + fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] + , tile_schedulers=tile_schedulers(data_type["sfd_type"]) + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] + , tile_schedulers=tile_schedulers(data_type["sfd_type"]) + ) + + + + +def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + math_instructions_1sm = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + ] + + tile_schedulers = [ + TileSchedulerType.Default, + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. 
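The loop that follows this comment (and its many repetitions below) sets the D alignment to one 128-bit line of elements. A worked sketch using the bit widths added to library.py later in this patch, illustrative only:

# The D alignment used throughout these generators is 128 bits' worth of elements.
DataTypeSize = {"s8": 8, "f16": 16, "s32": 32, "f32": 32, "e2m1": 4}  # bits, as in library.py

def d_alignment(d_type):
  return 128 // DataTypeSize[d_type]

assert d_alignment("s32") == 4    # int32 D: 4-element alignment
assert d_alignment("s8") == 16    # int8 D (mixed-precision output): 16-element alignment
assert d_alignment("f16") == 8
assert d_alignment("e2m1") == 32  # packed 4-bit FP4 D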
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. 
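Before the alignment loop below, the epilogue-schedule rule used by the 2SM branch above (and by the other 2SM generators in this file) is worth a compact restatement, illustrative only:

# 2SM kernels pick the explicit 2SM TMA epilogue only when the instruction M extent is 128;
# otherwise the epilogue is left to ScheduleAuto. Strings stand in for the enum values.
def select_2sm_epilogue(instruction_shape):
  return "TmaWarpSpecialized2Sm" if instruction_shape[0] == 128 else "ScheduleAuto"

assert select_2sm_epilogue([128, 256, 32]) == "TmaWarpSpecialized2Sm"
assert select_2sm_epilogue([256, 256, 32]) == "ScheduleAuto"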
+      for layout in layouts:
+        layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
+
+      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed,
+        [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
+
+
+
+#
+# Kernels using the stream-K tile scheduler.
+# A reduced set of kernels is generated for these schedulers to reduce functional
+# and performance testing time.
+#
+
+def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
+  if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
+    return
+
+  # layouts for ABC and their alignments.
+  layouts = [
+    [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]],
+    [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]],
+
+  ]
+
+  data_types = [
+    {
+      "a_type" : DataType.f32,
+      "b_type" : DataType.f32,
+      "c_type" : DataType.f32,
+      "d_type" : DataType.f32,
+      "acc_type" : DataType.f32,
+      "epi_type" : DataType.f32,
+    }
+  ]
+
+  min_cc = 100
+  max_cc = 100
+  math_instructions_1sm = [
+    MathInstruction(
+      [128, 256, 8],
+      DataType.tf32, DataType.tf32, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+  ]
+
+  cluster_shapes_1sm = [
+    [1,2,1], [1,1,1], [1,4,1], [4,4,1]
+  ]
+
+  tile_schedulers = [
+    TileSchedulerType.StreamK,
+  ]
+
+  # 1xSM MMA kernels
+  for math_inst in math_instructions_1sm:
+    tile_descriptions = []
+    for cluster_shape in cluster_shapes_1sm:
+      multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
+      tile_descriptions.append(
+        TileDescription([
+          math_inst.instruction_shape[0] * multiplier_1sm[0],
+          math_inst.instruction_shape[1] * multiplier_1sm[1],
+          math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
+
+    CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
+      [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
+      tile_schedulers=tile_schedulers)
+
+  # 2xSM MMA kernels
+  math_instructions_2sm = [
+    MathInstruction(
+      [256, 256, 8],
+      DataType.tf32, DataType.tf32, DataType.f32,
+      OpcodeClass.TensorOp,
+      MathOperation.multiply_add),
+  ]
+
+  cluster_shapes_2sm = [
+    [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
+  ]
+
+  for math_inst in math_instructions_2sm:
+    tile_descriptions = []
+    for cluster_shape in cluster_shapes_2sm:
+      multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
+      tile_descriptions.append(
+        TileDescription([
+          math_inst.instruction_shape[0] * multiplier_2sm[0],
+          math_inst.instruction_shape[1] * multiplier_2sm[1],
+          math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
+          0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
+
+    if math_inst.instruction_shape[0] == 128:
+      epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
+    else:
+      epi_schedule = EpilogueScheduleType.ScheduleAuto
+
+    CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types,
+      [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
+
+def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
+
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. C alignment will be set later based on output type + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + math_instructions_1sm = [ + MathInstruction( + [128, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [ + [1,2,1], [1,1,1], [4,4,1] + ] + + tile_schedulers = [ + TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + } + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + } + ] + # Set alignment d based on Destination format. 
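The mixed-precision branch above only adds the extra A/B-typed output kernels when the accumulator type is wider than the inputs, so the f16-accumulating stream-K instruction here skips it entirely. Restated as a plain predicate, illustrative only, before the alignment loop that follows:

# Emit D-in-input-type variants only for genuinely mixed-precision accumulation,
# so e.g. f16 x f16 with f16 accumulation does not get a duplicate kernel.
def emit_mixed_output_kernels(element_a, element_accumulator):
  return element_a != element_accumulator

assert emit_mixed_output_kernels("f16", "f32")      # f16 inputs, f32 accumulation: add f16-D kernels
assert not emit_mixed_output_kernels("f16", "f16")  # f16 accumulation: skip, already covered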
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [256, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_2sm = [ + [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + } + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + } + ] + # Set alignment d based on Destination format. 
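For a sense of how small this "reduced set" is, the 2SM half of this generator contributes roughly the following kernel count, assuming CreateGemmUniversal3xOperator emits one kernel per (layout, tile description, data type, schedule) combination; that behavior is an assumption, since its definition is outside this hunk:

# Rough count of the 2SM f16 stream-K kernels generated above, under the assumption stated.
num_layouts = 5        # A/B/D layout combinations listed at the top of the function
num_clusters_2sm = 5   # [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
num_data_types = 1     # f16 inputs with f16 accumulation only
num_schedules = 1      # TmaWarpSpecialized2SmSm100 with the selected epilogue
print(num_layouts * num_clusters_2sm * num_data_types * num_schedules)  # 25 kernel variants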
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + + +def GenerateSM100(manifest, cuda_version): + # + # Dense Gemm + # + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) + + GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) + GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version) + + GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version) + + GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) + + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) + # + # Block Scaled Gemm + # + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + +################################################################################################### + +def GenerateSM90_Conv3x(manifest, cuda_version, + log_indent_level: int = 0): + """ + Generate CUTLASS 3 convolution kernel(s) for SM90. + + This is meant to be called from GenerateSM90. + """ + log_debug_line('GenerateSM90_Conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + minimum_compute_capability = 90 + maximum_compute_capability = 90 + + spatial_dims = (2, 3) + + # This function only generates kernels that use TMA. + byte_alignment_required_by_tma = 16 + tma_byte_alignments = { + 'A': byte_alignment_required_by_tma, + 'B': byte_alignment_required_by_tma, + 'C': byte_alignment_required_by_tma, + } + + # For tuples of one element, the element needs to end with comma. + all_byte_alignments = ( + tma_byte_alignments, + ) + + # MMA shapes (MMA_M, MMA_N, MMA_K): + # + # Different hardware MMA instructions may have different MMA shapes. + # This function may generate kernels with different MMA shapes for + # different data types, either because the hardware only supports + # certain shapes for certain types, or for performance reasons + # (CUTLASS doesn't need to generate all valid kernels for the + # profiler library, just the best-performing ones). + # + # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) + # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, + # where 4, the "number of MMA instructions per tile," is determined + # through some combination of modeling and experiment. + # + # For performance on sm90, generally CUTLASS generates 64x128 + # instead of 128x64. + mma_64x64x16 = ( 64, 64, 16) + mma_64x64x8 = ( 64, 64, 8) + + num_mma_per_tile = 4 + + # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, + # but not included, because they tend not to perform as well. + cluster_shapes = ( + (2, 1, 1), + (1, 2, 1), + ) + + fp16 = DataType.f16 + bf16 = DataType.bf16 + fp32 = DataType.f32 + s8 = DataType.s8 + s32 = DataType.s32 + + # When generating kernels, the usual way is to specify 4 types, + # (A, B, Acc, C/D). Tests instead have 5 types, + # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), + # where ElementCompute is also called 'epi_type', + # and corresponds to the type of epilogue activations. + # This script maps tests' 5 types to 4 types + # by making ElementCompute the same as ElementOut. 
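To make the tile-naming note above concrete before the data-type dictionaries that follow: with num_mma_per_tile = 4, an MMA shape of 64x64x16 is reported as a 64x64x64 tile in the kernel name. A tiny sketch, illustrative only and not part of the patch:

# Kernel names use tile shapes, where TILE_K = num_mma_per_tile * MMA_K.
def tile_from_mma(mma_shape, num_mma_per_tile=4):
  m, n, k = mma_shape
  return (m, n, num_mma_per_tile * k)

assert tile_from_mma((64, 64, 16)) == (64, 64, 64)  # fp16/bf16/int8 conv kernels above
assert tile_from_mma((64, 64, 8))  == (64, 64, 32)  # tf32 (fp32 data) conv kernels above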
+ + fp16_fp32_fp16_fp32 = { + 'a_type': fp16, # ElementAct(ivation) + 'b_type': fp16, # ElementF(i)lt(er) + 'c_type': fp32, # ElementAcc + 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) + 'acc_type': fp16, # ElementAcc + 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) + } + fp16_fp32_fp32_fp32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + } + fp32_fp32_fp32_fp32 = { + 'a_type': fp32, + 'b_type': fp32, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + } + s8_s32_s32_s32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + } + + # Other NVIDIA libraries may have the habit of specifying data types like this. + bf16bf16_bf16f32_f32 = { + 'a_type': bf16, + 'b_type': bf16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + } + f16f16_f16f16_f16 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp16, + 'epi_type': fp16, + } + f16f16_f16f32_f32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp32, + 'epi_type': fp32, + } + f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 + + i8i8_i8i32_f32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + } + + # Each element in the outermost iterable is one combination of + # + # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) + # + # for which to generate a kernel. spatial_dimension is the spatial + # dimension of the convolution: either 1, 2, or 3. byte_alignments + # is a triple of required minimum byte alignments for A, B, and C. + # + # Note that itertools functions produce a single-pass generator. + # The code doesn't need a multipass iterable, but if one did, one + # could call `tuple` or `list` on the generator. + # + # While this happens to use the same cluster sizes for each element, + # the code doesn't require that. Different convolution kinds, data + # types, or mma sizes might have different optimal cluster sizes. + combinations_of_parameters = chain( + # The following are all the kernels exercised in the unit tests. + # Please try to keep in sync with the unit tests. + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + s8_s32_s32_s32, + ), + all_byte_alignments, + ( + mma_64x64x16, + ), + cluster_shapes + ), + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp32_fp32_fp32_fp32, + ), + all_byte_alignments, + ( + mma_64x64x8, + ), + cluster_shapes + ), + product( + ( + ConvKind.Dgrad, + ConvKind.Wgrad + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + ), + all_byte_alignments, + ( + mma_64x64x16, + ), + cluster_shapes + ), + # Kernels not necessarily in the unit tests, but used elsewhere + # and thus useful to have generated for profiling. They may + # duplicate kernels above. All of them are 2-D. In general, # CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the # hardware permits 128 x 64. 
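The enumeration above is plain itertools; the following self-contained sketch shows the same chain-of-products pattern with simplified stand-in values rather than the actual tables:

# Minimal sketch of the enumeration style used above: chain() concatenates several product()
# blocks, each yielding (conv_kind, spatial_dim, data_types, alignments, mma, cluster).
from itertools import chain, product

combinations = chain(
  product(("fprop",), (2, 3), ("fp16_fp32_fp16_fp32",), ((16, 16, 16),), ((64, 64, 16),), ((2, 1, 1), (1, 2, 1))),
  product(("dgrad", "wgrad"), (2, 3), ("fp16_fp32_fp32_fp32",), ((16, 16, 16),), ((64, 64, 16),), ((2, 1, 1), (1, 2, 1))),
)
# itertools returns single-pass generators, as the comment above notes; materialize if reuse is needed.
combinations = list(combinations)
print(len(combinations))  # 1*2*1*1*1*2 + 2*2*1*1*1*2 = 4 + 8 = 12 parameter combinations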
( @@ -6973,6 +8889,13 @@ def define_parser(): GenerateSM80(manifest, args.cuda_version) GenerateSM89(manifest, args.cuda_version) GenerateSM90(manifest, args.cuda_version) + + + blackwell_enabled_arch = args.architectures == "100a" + if blackwell_enabled_arch: + GenerateSM100(manifest, args.cuda_version) + + if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index c00992f29f..dc8a6f96f4 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -83,6 +83,14 @@ class DataType(enum.Enum): s64 = enum_auto() e4m3 = enum_auto() e5m2 = enum_auto() + f8 = enum_auto() + f6 = enum_auto() + f4 = enum_auto() + e3m2 = enum_auto() + e2m3 = enum_auto() + e2m1 = enum_auto() + ue8m0 = enum_auto() + ue4m3 = enum_auto() f16 = enum_auto() bf16 = enum_auto() f32 = enum_auto() @@ -117,6 +125,9 @@ class DataType(enum.Enum): DataType.f64: 'd', DataType.cf32: 'c', DataType.cf64: 'z', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', } # @@ -137,6 +148,14 @@ class DataType(enum.Enum): DataType.s64: "s64", DataType.e4m3: 'e4m3', DataType.e5m2: 'e5m2', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', + DataType.e2m3: 'e2m3', + DataType.e3m2: 'e3m2', + DataType.e2m1: 'e2m1', + DataType.ue8m0: 'ue8m0', + DataType.ue4m3: 'ue4m3', DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", @@ -178,6 +197,14 @@ class DataType(enum.Enum): DataType.s64: "int64_t", DataType.e4m3: 'cutlass::float_e4m3_t', DataType.e5m2: 'cutlass::float_e5m2_t', + DataType.f8: 'cutlass::type_erased_dynamic_float8_t', + DataType.f6: 'cutlass::type_erased_dynamic_float6_t', + DataType.f4: 'cutlass::type_erased_dynamic_float4_t', + DataType.e2m3: 'cutlass::float_e2m3_t', + DataType.e3m2: 'cutlass::float_e3m2_t', + DataType.e2m1: 'cutlass::float_e2m1_t', + DataType.ue8m0: 'cutlass::float_ue8m0_t', + DataType.ue4m3: 'cutlass::float_ue4m3_t', DataType.f16: "cutlass::half_t", DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", @@ -219,6 +246,14 @@ class DataType(enum.Enum): DataType.s64: 64, DataType.e4m3: 8, DataType.e5m2: 8, + DataType.f8: 8, + DataType.f6: 6, + DataType.f4: 4, + DataType.e2m3: 6, + DataType.e3m2: 6, + DataType.e2m1: 4, + DataType.ue8m0: 8, + DataType.ue4m3: 8, DataType.f16: 16, DataType.bf16: 16, DataType.f32: 32, @@ -447,6 +482,22 @@ class KernelScheduleType(enum.Enum): TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() ImplicitTmaWarpSpecializedSm90 = enum_auto() + + TmaWarpSpecialized1SmSm100 = enum_auto() + TmaWarpSpecialized2SmSm100 = enum_auto() + + + BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() + BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + + Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + # KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', @@ -462,6 +513,22 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', 
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', + KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', + + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', + + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + } # @@ -479,6 +546,22 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', + + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', + + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + } class EpilogueScheduleType(enum.Enum): @@ -487,6 +570,9 @@ class EpilogueScheduleType(enum.Enum): NoSmemWarpSpecialized = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecialized1Sm = enum_auto() + TmaWarpSpecialized2Sm = enum_auto() + # EpilogueScheduleTag = { EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', @@ -494,6 +580,8 @@ class EpilogueScheduleType(enum.Enum): EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', + EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', + EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm', } # @@ -503,13 +591,18 @@ class EpilogueScheduleType(enum.Enum): EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized1Sm: '', + EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', } class EpilogueFunctor3x(enum.Enum): 
LinearCombination = enum_auto() + LinearCombinationBlockScaleFactor = enum_auto() + # EpilogueFunctor3xTag = { EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination', + EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor', } class TileSchedulerType(enum.Enum): @@ -595,12 +688,15 @@ class OpcodeClass(enum.Enum): TensorOp = enum_auto() WmmaTensorOp = enum_auto() SparseTensorOp = enum_auto() + BlockScaledTensorOp = enum_auto() + OpcodeClassNames = { OpcodeClass.Simt: 'simt', OpcodeClass.TensorOp: 'tensorop', OpcodeClass.WmmaTensorOp: 'wmma_tensorop', OpcodeClass.SparseTensorOp: 'sptensorop', + OpcodeClass.BlockScaledTensorOp: 'bstensorop' } OpcodeClassTag = { @@ -608,6 +704,7 @@ class OpcodeClass(enum.Enum): OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp', + OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp' } ################################################################################################### @@ -688,6 +785,8 @@ class GemmKind(enum.Enum): PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() Grouped = enum_auto() + BlockScaledUniversal3x = enum_auto() + # GemmKindNames = { GemmKind.Gemm: "gemm", @@ -698,6 +797,7 @@ class GemmKind(enum.Enum): GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.Grouped: "gemm_grouped", + GemmKind.BlockScaledUniversal3x: "gemm_block_scaled" } # @@ -871,6 +971,8 @@ class GroupMode(enum.Enum): GroupMode.Depthwise: 'depthwise', } +DynamicClusterShape = [0, 0, 1] + ################################################################################################### # @@ -879,6 +981,7 @@ def __init__(self, instruction_shape, \ element_a, element_b, element_accumulator, \ opcode_class, math_operation = MathOperation.multiply_add \ + , element_scale_factor = None ): self.instruction_shape = instruction_shape @@ -887,6 +990,8 @@ def __init__(self, self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation + self.element_scale_factor = element_scale_factor + # class TileDescription: diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index 78f6b887ef..ad81db6d0d 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -522,6 +522,7 @@ def __init__(self, args = None): arch_conditional_cc = [ '90a', + '100a' ] architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures] diff --git a/python/setup_library.py b/python/setup_library.py index 40f9283681..003111f307 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='cutlass_library', - version='3.7.0', + version='3.8.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 2b9cd02e61..822dfe16a1 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='pycute', - version='3.7.0', + version='3.8.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt index 6e8eeb7fbe..2f82812f2c 100644 --- 
a/test/self_contained_includes/CMakeLists.txt +++ b/test/self_contained_includes/CMakeLists.txt @@ -113,6 +113,15 @@ set(header_files_to_check cute/arch/mma_sm90_gmma.hpp cute/arch/mma.hpp cute/arch/util.hpp + + cute/arch/cluster_sm100.hpp + cute/arch/copy_sm100.hpp + cute/arch/copy_sm100_tma.hpp + cute/arch/mma_sm100.hpp + cute/arch/mma_sm100_desc.hpp + cute/arch/mma_sm100_umma.hpp + # cute/arch/tmem_allocator_sm100.hpp + # cute/atom # cute/atom/copy_atom.hpp # cute/atom/copy_traits.hpp @@ -131,6 +140,10 @@ set(header_files_to_check cute/atom/mma_traits_sm80.hpp cute/atom/mma_traits_sm90.hpp cute/atom/mma_traits_sm90_gmma.hpp + + cute/atom/mma_traits_sm100.hpp + cute/atom/partitioner.hpp + # cutlass cutlass/aligned_buffer.h cutlass/array.h @@ -185,12 +198,20 @@ set(header_files_to_check cutlass/version.h cutlass/wmma_array.h cutlass/workspace.h + + cutlass/exmy_base.h + cutlass/float_subbyte.h + # cutlass/platform cutlass/platform/platform.h # cutlass/pipeline cutlass/pipeline/pipeline.hpp cutlass/pipeline/sm90_pipeline.hpp + + cutlass/pipeline/sm100_pipeline.hpp + + # cutlass/detail cutlass/detail/cluster.hpp cutlass/detail/collective.hpp @@ -199,6 +220,10 @@ set(header_files_to_check cutlass/detail/layout.hpp cutlass/detail/mainloop_fusion_helper_bgrada.hpp cutlass/detail/mma.hpp + + cutlass/detail/sm100_blockscaled_layout.hpp + + # cutlass/arch cutlass/arch/arch.h cutlass/arch/barrier.h diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 875924ed0f..676b111e19 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -118,6 +118,7 @@ void FilterArchitecture() { { "SM80*", 80, kMaxDevice}, { "SM89*", 89, 89}, { "SM90*", 90, 90}, + { "SM100*", 100, 100}, { 0, 0, false } }; diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 9bd727f613..e4c7478934 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -679,6 +679,11 @@ struct GetName { static constexpr char name[] = "float_e4m3_t"; }; +template <> +struct GetName { + static constexpr char name[] = "float_e5m2_t"; +}; + template <> struct GetName { static constexpr char name[] = "half_t"; @@ -724,13 +729,20 @@ using VectorConvertTypes = ::testing::Types< ResultSourcePair, ResultSourcePair, + ResultSourcePair, + ResultSourcePair, ResultSourcePair, ResultSourcePair, + ResultSourcePair, + ResultSourcePair, ResultSourcePair, ResultSourcePair, + ResultSourcePair, + ResultSourcePair, ResultSourcePair, ResultSourcePair, + ResultSourcePair, ResultSourcePair, ResultSourcePair, ResultSourcePair diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index fedee3854f..5fdda499b2 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -29,6 +29,10 @@ add_custom_target(cutlass_test_unit_gemm_device) add_custom_target(test_unit_gemm_device) + +add_subdirectory(sm100_blockscaled_tensorop_gemm) + + ################################################################################ function(cutlass_test_unit_gemm_device_add_deps NAME) @@ -433,12 +437,12 @@ cutlass_test_unit_gemm_device_add_executable( gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu + sm80_gemm_f64_f64_f64_tensor_op_f64.cu + # SM90 device level tests gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu - sm80_gemm_f64_f64_f64_tensor_op_f64.cu - 
gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -821,3 +825,147 @@ if (CUTLASS_NVCC_DEVICE_COMPILE) endif() + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_sm100_fp16_gemm + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f16_f16_f32_tensor_op_f32.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm100_stream_k + + sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_sm100_bf16_gemm + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu +) + + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_stride_batch_alpha_beta_sm100 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_runtime_datatype_sm100 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu + sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu + sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu + sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_16b_tensorop_sm100_ptr_array + + # 14 (9 + 5) unit tests + sm100_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu + sm100_gemm_bf16_bf16_bf16_tensor_op_f32_ptr_array.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_16b_tensorop_sm100_group_gemm + + sm100_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_16b_mixed_tensorop_sm100_ptr_array + + # 14 (9 + 5) unit tests + sm100_gemm_f16_f16_f32_tensor_op_f32_ptr_array.cu + sm100_gemm_f16_f16_f16_tensor_op_f16_ptr_array.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_32b_tensorop_sm100_ptr_array + + # 10 unit tests + sm100_gemm_f32_f32_f32_tensor_op_f32_ptr_array.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_32b_tensorop_sm100_group_gemm + + # 10 unit tests + sm100_gemm_f32_f32_f32_tensor_op_f32_group_gemm.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_8b_tensorop_sm100_ptr_array + + # 12 unit tests + sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu + sm100_gemm_f8_f8_f8_tensor_op_f32_ptr_array.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_8b_tensorop_sm100_group_gemm + + # 8 unit tests + sm100_gemm_f8_f8_f8_tensor_op_f32_group_gemm.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_mxf8_training_sm100_group_gemm + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_group_gemm.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_mxf4xmxf8_sm100_group_gemm + + # 8 unit tests + sm100_gemm_mxf4_mxf8_mxf8_tensor_op_f32_group_gemm.cu +) + 
+cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_blockscaled_gemm_device_fp4_tensorop_sm100_ptr_array + + # 8 unit tests + sm100_gemm_f4_f4_f32_tensor_op_f32_ptr_array.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_blockscaled_gemm_device_fp4_tensorop_sm100_group_gemm_1 + + # 8 unit tests + sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu +) +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_blockscaled_gemm_device_fp6_tensorop_sm100_ptr_array + + # 8 unit tests + sm100_gemm_f6_f6_f32_tensor_op_f32_ptr_array.cu +) +endif() + + diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index a9db871594..bf1d11fed2 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -139,7 +139,7 @@ struct ElementComputeType { }; template -struct ElementComputeType> { +struct ElementComputeType>> { using Type = typename Gemm::EpilogueOutputOp::ElementCompute; }; @@ -149,10 +149,22 @@ struct ElementScalarType { }; template -struct ElementScalarType> { +struct ElementScalarType>> { using Type = typename Gemm::EpilogueOutputOp::ElementScalar; }; + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + template struct IsSfdEpi : cute::false_type {}; @@ -274,9 +286,26 @@ bool initialize_tensor( scope_max = 2; scope_min = 0; } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; scope_min = -1; + + } + } else{ scope_max = 4; @@ -491,11 +520,24 @@ struct HostCollectiveMainloop { Arguments to_args() { + + // Runtime datatype selection + if constexpr (not cute::is_same_v) { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b + }; + } + else { + Arguments arguments = { tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b }; return arguments; + } } auto to_host_args(ProblemShapeType problem_size) { @@ -513,9 +555,19 @@ struct HostCollectiveMainloop { auto B = make_tensor(make_iterator(tensor_B.host_data()), make_layout(make_shape(N, K, L), stride_b)); + + auto dummy_SFA = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto dummy_SFB = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + cutlass::reference::host::GettMainloopParams mainloop_params{}; mainloop_params.A = A; @@ -840,6 +892,213 @@ struct HostCollectiveMainloop::HostCollectiveMainloopSparse; }; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = 
cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Blk_MN = typename Sm100BlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm100BlkScaledConfig::Blk_SF; + using SfAtom = typename Sm100BlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideB stride_b; + + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + tensor_A.resize(a_coord, 
cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A));
+    tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B));
+
+    EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022));
+    EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021));
+
+    // It is possible to randomly initialize to all zeros, so override this with non-zeros
+    // in the upper left corner of each operand.
+    tensor_A.host_view().at({0, 0}) = ElementA(1);
+    tensor_B.host_view().at({0, 0}) = ElementB(1);
+
+    tensor_A.sync_device();
+    tensor_B.sync_device();
+
+    using namespace cute;
+    auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{})));
+    auto m_blks = cutlass::ceil_div(M, Blk_MN{});
+    auto n_blks = cutlass::ceil_div(N, Blk_MN{});
+    layout_sfa = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL);
+    layout_sfb = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL);
+
+    // 2.x host tensor does not natively contain a batch stride or coord, so we spoof it by folding it into the outer mode
+    auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{});
+    auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{});
+
+    tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A));
+    tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B));
+
+    EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024));
+    EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025));
+
+    // It is possible to randomly initialize to all zeros, so override this with non-zeros
+    // in the upper left corner of each operand.
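The scale-factor extents computed above fold the batch mode L into the outer dimension, just like A and B. A rough Python restatement of that arithmetic follows; Blk_MN, Blk_SF, and the scale-factor vector size are compile-time constants in the real code, and the concrete values used here are assumptions for illustration only:

# Rough sketch of the SFA/SFB host-tensor extents; not part of the patch.
def ceil_div(a, b):
  return (a + b - 1) // b

def sf_tensor_extents(M, N, K, L, sf_vec_size=32, blk_mn=128, blk_sf=4):
  k_blks = ceil_div(K, sf_vec_size * blk_sf)    # K extent of one SF atom (assumed sf_vec_size * blk_sf)
  m_blks = ceil_div(M, blk_mn)
  n_blks = ceil_div(N, blk_mn)
  sfa = (m_blks * blk_mn * L, k_blks * blk_sf)  # rows folded with the batch mode L, as above
  sfb = (n_blks * blk_mn * L, k_blks * blk_sf)
  return sfa, sfb

print(sf_tensor_extents(M=512, N=256, K=1024, L=2))  # ((1024, 32), (512, 32))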
+ tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + + template struct HostCollectiveDefaultEpilogue { // fusion types are potentially void if the fusion is not supported @@ -1127,6 +1386,22 @@ struct HostCollectiveEpilogue { typename Gemm::EpilogueOutputOp>; static_assert(cute::is_base_of_v); + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? 
FusionOp::SFVecSize : 1; + using ElementSFD = non_void_t; + using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig< + SFD_VectorSize + >; + using Blk_MN = typename Sm100BlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm100BlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom; + cutlass::HostTensor tensor_SFD; + cutlass::HostTensor reference_SFD; + using ElementCompute = typename FusionOp::ElementCompute; using ElementScalar = typename FusionOp::ElementScalar; using ElementBias = non_void_t; @@ -1412,6 +1687,22 @@ struct HostCollectiveEpilogue { } } + + if constexpr (IsBlockScaleSupported) { + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + return cutlass::make_Coord(m_blks * Blk_MN{} * L, n_blks * Blk_SF{}); + }(); + tensor_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D)); + reference_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false); + tensor_SFD.sync_device(); + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + return true; } @@ -1510,6 +1801,17 @@ struct HostCollectiveEpilogue { } } + + if constexpr (IsBlockScaleSupported) { + tensor_SFD.sync_host(); + bool passed_sf = equality_check(reference_SFD.host_view(), tensor_SFD.host_view()); + if(!passed_sf) { + std::cout<<"SF is incorrect"<, cutlass::plus , false /*PerColumnBias_*/ + , SfGenStrategy > epilogue_params{}; epilogue_params.C = C; @@ -1791,6 +2122,11 @@ struct HostCollectiveEpilogue { epilogue_params.Vbeta = Vbeta; } } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } return epilogue_params; } }; @@ -1801,6 +2137,8 @@ template < bool force_legacy_epilogue = false, typename ElementA = typename Gemm::GemmKernel::ElementA, typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* > struct TestbedImpl { // Kernel data types @@ -1822,6 +2160,20 @@ struct TestbedImpl { using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + uint32_t sm_count; // Used to force multi-wave tests for persistent kernel schedules constexpr static int MaxSmCount = 16; @@ -2007,6 +2359,8 @@ struct TestbedImpl { detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, DecompositionMode decomposition_mode = DecompositionMode::Heuristic + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB 
runtime_input_datatype_b = {} ) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) @@ -2073,6 +2427,13 @@ struct TestbedImpl { mainloop_args = collective_mma_inputs.to_args(); + + if constexpr (IsRuntimeDataType) { + mainloop_args.runtime_data_type_a = runtime_input_datatype_a; + mainloop_args.runtime_data_type_b = runtime_input_datatype_b; + } + + arguments = { cutlass::gemm::GemmUniversalMode::kGemm, @@ -2195,6 +2556,8 @@ template < bool force_legacy_epilogue = false, typename ElementA = typename Gemm::GemmKernel::ElementA, typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* > struct Testbed3x { @@ -2204,6 +2567,8 @@ struct Testbed3x { force_legacy_epilogue, ElementA, ElementB + , RuntimeDatatypeA + , RuntimeDatatypeB >; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; @@ -2244,10 +2609,13 @@ struct Testbed3x { DecompositionMode decomposition_mode = DecompositionMode::Heuristic, bool profiling = false, detail::Iterations iterations = detail::Iterations{} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} ) { return impl_.run( problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode + , runtime_input_datatype_a, runtime_input_datatype_b ); } }; @@ -2298,11 +2666,629 @@ bool TestGemmPerf3x(int iterations = 20) { return true; } + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +template < + typename Gemm, + typename RuntimeDataTypeA, + typename RuntimeDataTypeB, + bool force_legacy_epilogue = false> +bool TestRuntimeDataTypeSmall( + RuntimeDataTypeA runtime_input_datatype_a, + RuntimeDataTypeB runtime_input_datatype_b, + double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + + CtaShape_MNK cta_shape; + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment * MultiplierOffsetK, 512 + max_alignment * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + [[maybe_unused]] constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment; + + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + } + for (auto splits : problem_splits) { + + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && + runtime_input_datatype_b == 
cute::UMMA::MXF4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + else { + std::cout << "Unsupported configuration for runtime datatype MXFP4." << std::endl; + return false; + } + } + + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + static_assert((cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v), + "Runtime datatype must be selected with an appropriate static umbrella data type."); + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // f6xf4 + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." 
<< std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e3m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E3M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m3 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M3) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e5m2_e5m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e5m2 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e5m2_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + 
RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for runtime datatype."); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment_k = 0; + int max_alignment_m = 0; + int max_alignment_n = 0; + + if constexpr (apply_alignment_offset) { + max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_n = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_m = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + } + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment_k * MultiplierOffsetK, 512 + max_alignment_k * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + std::vector raster_order_options = {RasterOrderOptions::Heuristic}; + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to 
cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment_m; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment_n; + int l = test_batched_alpha_beta && wave == waves[0] && k == problem_size_k[0] ? 2 : 1; // only test the smallest problem size + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, l}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + for (RasterOrderOptions raster_order : raster_order_options) { + std::vector problem_splits = {detail::Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{4}); + } + } + for (auto splits : problem_splits) { + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, // raster_order + detail::MaxSwizzleSize(0), + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + EXPECT_TRUE(passed) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << l << " FAILED.\n"; + return false; + } + } // splits + } // raster_order + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + return TestSmall(alpha, + beta, + check_relative_equality, + use_device_scalars, + vector_scale_mode, + override_problem_size_k); +} + + + template < typename Gemm, template class ActivationFunctor = cutlass::epilogue::thread::Identity > -bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { +bool TestAll(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -2451,7 +3437,7 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative } template -bool TestAllBiasElementwise(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { +bool TestAllBiasElementwise(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { return TestAll(alpha, beta, check_relative_equality); } diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index db1114ba58..1ae073c4ac 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -111,6 +111,18 @@ struct ElementScalarType +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + // The maximum swizzle size to use // // This class, like Splits above makes it harder to confuse @@ -212,9 +224,26 @@ bool initialize_tensor( scope_max = 2; scope_min = 0; } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; scope_min = -1; + + } + } else{ scope_max = 4; @@ -487,6 +516,277 @@ struct HostCollectiveMainloop { } }; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename 
Gemm::GemmKernel::ElementSF; + using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Blk_MN = typename Sm100BlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm100BlkScaledConfig::Blk_SF; + using SfAtom = typename Sm100BlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using InternalLayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using InternalLayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + std::vector stride_a_host; + std::vector stride_b_host; + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + std::vector layout_sfa_host; + std::vector layout_sfb_host; + cutlass::DeviceAllocation layout_sfa_device; + cutlass::DeviceAllocation layout_sfb_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + std::vector> tensors_SFA; + std::vector> tensors_SFB; + + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + cutlass::DeviceAllocation device_tensors_SFA; + cutlass::DeviceAllocation device_tensors_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + tensors_SFA.clear(); + tensors_SFB.clear(); + layout_sfa_host.clear(); + layout_sfb_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be 
treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + + using namespace cute; + + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa_host.push_back(Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); + layout_sfb_host.push_back(Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof it by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); + + tensors_SFA.push_back(cutlass::HostTensor(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A))); + tensors_SFB.push_back(cutlass::HostTensor(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_SFA[i].host_view(), init_A, seed + 2024 + i)); + EXPECT_TRUE(initialize_tensor(tensors_SFB[i].host_view(), init_B, seed + 2025 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand.
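Each group in this ptr-array / grouped variant owns its own host and device tensors; to_args() below then passes the kernel arrays of per-group device pointers rather than a single base pointer. A minimal CUDA sketch of that packing follows, using raw cudaMalloc/cudaMemcpy in place of cutlass::DeviceAllocation; the group count, buffer sizes and byte-sized element are illustrative assumptions only.

// Sketch only: stands in for the DeviceAllocation-based packing done in to_args().
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

int main() {
  int groups = 3;                                        // assumed group count
  std::vector<int> elems_per_group = {1024, 2048, 512};  // assumed per-group sizes

  // One device buffer per group (element type spoofed as a byte here).
  std::vector<std::uint8_t*> ptr_A_host(groups, nullptr);
  for (int i = 0; i < groups; ++i) {
    cudaMalloc(reinterpret_cast<void**>(&ptr_A_host[i]),
               elems_per_group[i] * sizeof(std::uint8_t));
  }

  // The array of device pointers itself must also live in device memory,
  // since the kernel indexes it by group id at run time.
  std::uint8_t** device_ptrs_A = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&device_ptrs_A), groups * sizeof(std::uint8_t*));
  cudaMemcpy(device_ptrs_A, ptr_A_host.data(), groups * sizeof(std::uint8_t*),
             cudaMemcpyHostToDevice);

  // ... device_ptrs_A (plus per-group strides and scale-factor layouts) would be
  // passed as the grouped mainloop arguments ...

  cudaFree(device_ptrs_A);
  for (auto p : ptr_A_host) { cudaFree(p); }
  return 0;
}

The same pattern is repeated for B, SFA and SFB, and for the per-group strides and scale-factor layouts.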
+ tensors_SFA[i].host_view().at({0, 0}) = ElementSF(1); + tensors_SFB[i].host_view().at({0, 0}) = ElementSF(1); + + tensors_SFA[i].sync_device(); + tensors_SFB[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + std::vector ptr_SFA_host(L); + std::vector ptr_SFB_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + ptr_SFA_host.at(i) = tensors_SFA[i].device_data(); + ptr_SFB_host.at(i) = tensors_SFB[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + device_tensors_SFA.reset(L); + device_tensors_SFA.copy_from_host(ptr_SFA_host.data()); + + device_tensors_SFB.reset(L); + device_tensors_SFB.copy_from_host(ptr_SFB_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + layout_sfa_device.reset(problem_shapes.groups()); + layout_sfa_device.copy_from_host(layout_sfa_host.data()); + + layout_sfb_device.reset(problem_shapes.groups()); + layout_sfb_device.copy_from_host(layout_sfb_host.data()); + + if constexpr (IsGroupGemm) { + return Arguments{ + device_tensors_A.get(), stride_a_device.get(), + device_tensors_B.get(), stride_b_device.get(), + device_tensors_SFA.get(), layout_sfa_device.get(), + device_tensors_SFB.get(), layout_sfb_device.get() + }; + } + else { + return Arguments{ + device_tensors_A.get(), stride_a_host[0], + device_tensors_B.get(), stride_b_host[0], + device_tensors_SFA.get(), layout_sfa_host[0], + device_tensors_SFB.get(), layout_sfb_host[0] + }; + } + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto SfA = make_tensor(tensors_SFA[batch].host_data(), layout_sfa_host[batch]); + + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + auto SfB = make_tensor(tensors_SFB[batch].host_data(), layout_sfb_host[batch]); + + return cutlass::reference::host::GettMainloopParams + {A, SfA, B, SfB}; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view() + << "\nSFA =\n" << tensors_SFA[batch].host_view() + << "\nSFB =\n" << tensors_SFB[batch].host_view(); + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFA[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFB[batch].host_view()), 0); + return true; + } +}; + + template struct HostCollectiveDefaultEpilogue { // fusion types are potentially 
void if the fusion is not supported @@ -803,6 +1103,24 @@ struct HostCollectiveEpilogue { using FusionOp = typename Gemm::EpilogueOutputOp; static_assert(cute::is_base_of_v); + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + using ElementSFD = non_void_t, ElementD>; + using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig< + SFD_VectorSize + >; + using Blk_MN = typename Sm100BlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm100BlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom; + std::vector> tensors_SFD; + std::vector> references_SFD; + cutlass::DeviceAllocation device_tensors_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; using ElementScalar = typename FusionOp::ElementScalar; using ElementBias = non_void_t; @@ -904,6 +1222,11 @@ struct HostCollectiveEpilogue { references_D.clear(); stride_c_host.clear(); stride_d_host.clear(); + + tensors_SFD.clear(); + references_SFD.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = std::max(problem_shapes.groups(), L); @@ -1034,6 +1357,26 @@ struct HostCollectiveEpilogue { } } + + if constexpr (IsBlockScaleSupported) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + // If block scaled output is supported we always have at least 1 SFD + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + return cutlass::make_Coord(m_blks * Blk_MN{}, n_blks * Blk_SF{}); + }(); + tensors_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D))); + references_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false)); + tensors_SFD[i].sync_device(); + } + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + return true; } @@ -1116,6 +1459,17 @@ struct HostCollectiveEpilogue { passed &= tmp; } } + + if constexpr (IsBlockScaleSupported) { + tensors_SFD[batch].sync_host(); + bool passed_sf = equality_check(references_SFD[batch].host_view(), tensors_SFD[batch].host_view()); + if(!passed_sf) { + std::cout<<"SF is incorrect"< ptr_SFD_host(L); + for (int32_t i = 0; i < L; ++i) { + ptr_SFD_host.at(i) = tensors_SFD[i].device_data(); + } + device_tensors_SFD.reset(L); + device_tensors_SFD.copy_from_host(ptr_SFD_host.data()); + + arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + } return arguments; @@ -1341,6 +1708,20 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); + + auto SfD = 
[&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), + Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. + return D; + } + }(); + + cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, @@ -1353,8 +1734,11 @@ struct HostCollectiveEpilogue { decltype(Valpha), decltype(Vbeta), ActivationFunctor + , decltype(SfD) + , Int , cutlass::plus , false + , SfGenStrategy > epilogue_params{}; epilogue_params.C = C; @@ -1397,6 +1781,12 @@ struct HostCollectiveEpilogue { epilogue_params.Vbeta = Vbeta; } } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + return epilogue_params; } }; @@ -1812,8 +2202,24 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0, using ElementB = typename Gemm::GemmKernel::ElementB; using TiledMma = typename Gemm::GemmKernel::TiledMma; int alignment_bits = 128; + + static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); + alignment_bits = cutlass::detail::get_input_alignment_bits(); + // For fp4 and fp6 mx kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_input = (alignment_bits / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits / cute::sizeof_bits::value); + + if constexpr (apply_alignment_offset) { + // If BlockScaled, then min alignment is SFVecSize + static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + if constexpr (IsBlockScaleSupported) { + alignment_input = cutlass::round_up(alignment_input, SFVecSize); + } + } + + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; CtaShape_MNK cta_shape; diff --git a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp index 4fc24ea4d6..8b00f98a97 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -258,6 +258,12 @@ struct Testbed3xTensorBroadcast { cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + + auto dummy_SFD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + using DummySFDVectorSize = cute::Int<0>; + + cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, @@ -270,6 +276,8 @@ struct Testbed3xTensorBroadcast { decltype(dummy_Valpha), decltype(dummy_Vbeta), ActivationFunctor, + decltype(dummy_SFD), + DummySFDVectorSize, cutlass::plus, PerColBias> epilogue_params{ alpha, diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt new file mode 100644 index 0000000000..01e79c9898 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt @@ -0,0 +1,150 @@ +# 
Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# + +# + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +add_custom_target( + cutlass_test_unit_gemm_device_sm100_blockscaled + DEPENDS + cutlass_test_unit_gemm_device_bstensorop_sm100_nvf4xnvf4 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf4 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf6 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf8 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf8 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf6 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf8 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf4 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf4 + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf6 +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_nvf4xnvf4 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + nvf4_nvf4_bf16_bf16.cu + nvf4_nvf4_bf16_bf16_features.cu + nvf4_nvf4_f16_nvfp4_epilogue.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf4 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf4_mxf4_void_f16_tn_layout.cu + mxf4_mxf4_void_f16_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf6 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf6_mxf6_void_bf16_tn_layout.cu + mxf6_mxf6_void_bf16_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf8 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf8_mxf8_void_f8_tn_layout.cu + mxf8_mxf8_void_f8_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf8 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf6_mxf8_void_f32_tn_layout.cu + mxf6_mxf8_void_f32_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf6 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + 
mxf8_mxf6_f16_f8_tn_layout.cu + mxf8_mxf6_f16_f8_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf8 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf4_mxf8_bf16_bf16_tn_layout.cu + mxf4_mxf8_bf16_bf16_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf4 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf8_mxf4_f16_bf16_tn_layout.cu + mxf8_mxf4_f16_bf16_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf4 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf6_mxf4_f16_f16_tn_layout.cu + mxf6_mxf4_f16_f16_nt_layout.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf6 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + mxf4_mxf6_f32_f16_tn_layout.cu + mxf4_mxf6_f32_f16_nt_layout.cu +) + +endif() diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu new file mode 100644 index 0000000000..6783195466 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief Unit tests for mxfp4xmxfp4 Block Scaled Gemm + + * A tensor: + * Types: {e2m1}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * B tensor: + * Types: {e2m1}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * Mma Tile Shapes supported depends on the layout for mxfp4 and mxfp4 mixed precision GEMM + The tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*)| NN | + |--------|---------------|----|----|-------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe2m1t_void_f16t_bstensorop_f32, 128x128x256_1x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy - underlying selection is KernelTmaWarpSpecialized1SmMxf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe2m1t_void_f16t_bstensorop_f32, 128x256x256_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe2m1t_void_f16t_bstensorop_f32, 256x256x256_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu new file mode 100644 index 0000000000..4fa2b750ee --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu @@ -0,0 +1,523 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+
+
+/*! \file
+    \brief Unit tests for mxfp4 x mxfp4 block-scaled GEMM
+
+    * A tensor:
+    *   Types: {e2m1}xue8m0
+    *   Layout: Row Major (T)
+    *   Alignment: 128 elements
+    * B tensor:
+    *   Types: {e2m1}xue8m0
+    *   Layout: Column Major (N)
+    *   Alignment: 128 elements
+    * The supported Mma tile shapes depend on the A/B layouts for the mxfp4 x mxfp4 GEMM.
+      The tile dimension with stride-1 must be divisible by 128, i.e., 128-element aligned.
+      Support Matrix (Y: Yes, N: No)
+        | 1/2 SM | Mma Tile Size | TN (*)| TT | NT | NN |
+        |--------|---------------|-------|----|----|----|
+        | 1SM    | 128x128x128   | Y     | Y  | Y  | Y  |
+        | 1SM    | 128x192x128   | Y     | N  | N  | Y  |
+        | 1SM    | 128x256x128   | Y     | Y  | Y  | Y  |
+        | 2SM    | 256x128x128   | Y     | N  | N  | Y  |
+        | 2SM    | 256x192x128   | Y     | N  | N  | Y  |
+        | 2SM    | 256x256x128   | Y     | Y  | Y  | Y  |
+
+    (*) Unit tests in this file
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass/epilogue/thread/activation.h"
+#include "../../../common/cutlass_unit_test.h"
+
+#include "../gemm_testbed_3x.hpp"
+
+using namespace cute;
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+
+TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_void_f16t_bstensorop_f32, 128x128x256_4x4x1_1sm_auto) {
+  // Describe A and B tensors
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignA = 32;
+  using GmemLayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignB = 32;
+  using GmemLayoutB = cutlass::layout::ColumnMajor;
+
+  // Describe C and D tensors
+  using ElementC = void;
+  constexpr int AlignC = 8;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::half_t;
+  constexpr int AlignD = 8;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+
+  // Tile and cluster shapes
+  // Collective MMA takes the tile shape of the MMA operation as input
+  using MmaTileShape_MNK = Shape<_128,_128,_256>;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_4,_4,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_128,_256>;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,   // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK,                              // Epilogue tile shape, and cluster
shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy - underlying selection is KernelTmaWarpSpecialized1SmMxf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_void_f16t_bstensorop_f32, 128x192x256_1x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_void_f16t_bstensorop_f32, 128x256x256_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_void_f16t_bstensorop_f32, 256x128x256_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_void_f16t_bstensorop_f32, 256x192x256_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy - underlying selection is KernelTmaWarpSpecialized1SmMxf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_void_f16t_bstensorop_f32, 256x256x256_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu new file mode 100644 index 0000000000..3d4d1df33e --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+
+
+/*! \file
+    \brief Unit tests for mxfp4 x mxfp6 block-scaled GEMM
+
+    * A tensor:
+    *   Types: {e2m1}xue8m0
+    *   Layout: Column Major (N)
+    *   Alignment: 128 elements
+    * B tensor:
+    *   Types: {e2m3,e3m2}xue8m0
+    *   Layout: Row Major (T)
+    *   Alignment: 128 elements
+    * The supported Mma tile shapes depend on the A/B layouts for the mxfp4 x mxfp6 mixed-precision GEMM.
+      The tile dimension with stride-1 must be divisible by 128, i.e., 128-element aligned.
+      Support Matrix (Y: Yes, N: No)
+        | 1/2 SM | Mma Tile Size | TN | TT | NT (*)| NN |
+        |--------|---------------|----|----|-------|----|
+        | 1SM    | 128x128x128   | Y  | Y  | Y     | Y  |
+        | 1SM    | 128x192x128   | Y  | N  | N     | Y  |
+        | 1SM    | 128x256x128   | Y  | Y  | Y     | Y  |
+        | 2SM    | 256x128x128   | Y  | N  | N     | Y  |
+        | 2SM    | 256x192x128   | Y  | N  | N     | Y  |
+        | 2SM    | 256x256x128   | Y  | Y  | Y     | Y  |
+
+    (*) Unit tests in this file
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass/epilogue/thread/activation.h"
+#include "../../../common/cutlass_unit_test.h"
+
+#include "../gemm_testbed_3x.hpp"
+
+using namespace cute;
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe2m3t_f32_f16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) {
+  // Describe A and B tensors
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignA = 128;
+  using GmemLayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  constexpr int AlignB = 128;
+  using GmemLayoutB = cutlass::layout::RowMajor;
+
+  // Describe C and D tensors
+  using ElementC = float;
+  constexpr int AlignC = 4;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::half_t;
+  constexpr int AlignD = 8;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+
+  // Tile and cluster shapes
+  // Collective MMA takes the tile shape of the MMA operation as input
+  using MmaTileShape_MNK = Shape<_128,_128,_128>;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_4,_4,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,   // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK,                              // Epilogue tile shape, and
cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe2m3t_f16_f16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe2m3t_f16_f16t_bstensorop_f32, 256x256x128_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu new file mode 100644 index 0000000000..ee5251f390 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+
+
+/*! \file
+    \brief Unit tests for mxfp4 x mxfp6 block-scaled GEMM
+
+    * A tensor:
+    *   Types: {e2m1}xue8m0
+    *   Layout: Row Major (T)
+    *   Alignment: 128 elements
+    * B tensor:
+    *   Types: {e2m3,e3m2}xue8m0
+    *   Layout: Column Major (N)
+    *   Alignment: 128 elements
+    * The supported Mma tile shapes depend on the A/B layouts for the mxfp4 x mxfp6 mixed-precision GEMM.
+      The tile dimension with stride-1 must be divisible by 128, i.e., 128-element aligned.
+      Support Matrix (Y: Yes, N: No)
+        | 1/2 SM | Mma Tile Size | TN (*)| TT | NT | NN |
+        |--------|---------------|-------|----|----|----|
+        | 1SM    | 128x128x128   | Y     | Y  | Y  | Y  |
+        | 1SM    | 128x192x128   | Y     | N  | N  | Y  |
+        | 1SM    | 128x256x128   | Y     | Y  | Y  | Y  |
+        | 2SM    | 256x128x128   | Y     | N  | N  | Y  |
+        | 2SM    | 256x192x128   | Y     | N  | N  | Y  |
+        | 2SM    | 256x256x128   | Y     | Y  | Y  | Y  |
+
+    (*) Unit tests in this file
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass/epilogue/thread/activation.h"
+#include "../../../common/cutlass_unit_test.h"
+
+#include "../gemm_testbed_3x.hpp"
+
+using namespace cute;
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m3n_f32_f16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) {
+  // Describe A and B tensors
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignA = 128;
+  using GmemLayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  constexpr int AlignB = 128;
+  using GmemLayoutB = cutlass::layout::ColumnMajor;
+
+  // Describe C and D tensors
+  using ElementC = float;
+  constexpr int AlignC = 4;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::half_t;
+  constexpr int AlignD = 8;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+
+  // Tile and cluster shapes
+  // Collective MMA takes the tile shape of the MMA operation as input
+  using MmaTileShape_MNK = Shape<_128,_128,_128>;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_4,_4,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,   // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK,                              // Epilogue tile shape, and
cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe3m2n_f32_f16t_bstensorop_f32, 128x192x128_2x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = float; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m3n_f16_f16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe3m2n_f32_f16t_bstensorop_f32, 256x128x128_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = float; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe3m2n_f32_f16t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = float; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m3n_f16_f16t_bstensorop_f32, 256x256x128_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu new file mode 100644 index 0000000000..4740641e6b --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp4xmxfp8 Block Scaled Gemm + + * A tensor: + * Types: {e2m1}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Row Major (T) + * Alignment: 16 elements + * Mma Tile Shapes supported: + For the A tensor (mxfp4 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*) | NN | + |--------|---------------|----|----|--------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | Y | Y | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | Y | Y | Y | + | 2SM | 256x192x128 | Y | Y | Y | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe5m2t_bf16_bf16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
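+      // The mx_* operand aliases above are presumably spelled with explicit data-type arguments
+      // matching the e2m1/e5m2 encodings in this test's name; a minimal sketch, assuming the
+      // standard CUTLASS 3.x block-scaled type aliases:
+      //   using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;  // e2m1 data + ue8m0 scale factors
+      //   using ElementB = cutlass::mx_float8_t<cutlass::float_e5m2_t>;  // e5m2 data + ue8m0 scale factors
+      // Likewise, the kernel composition at the end of each test presumably reads:
+      //   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int,int,int,int>, CollectiveMainloop, CollectiveEpilogue>;
+      //   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+      //   auto pass = test::gemm::device::TestAll<Gemm>();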
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe4m3t_bf16_bf16t_bstensorop_f32, 128x192x128_1x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe5m2t_bf16_bf16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe4m3t_bf16_bf16t_bstensorop_f32, 256x128x128_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe4m3t_bf16_bf16t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1n_ue8m0xe5m2t_bf16_bf16t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu new file mode 100644 index 0000000000..4cf8a4e1a0 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp4xmxfp8 Block Scaled Gemm + + * A tensor: + * Types: {e2m1}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Column Major (N) + * Alignment: 16 elements + * Mma Tile Shapes supported: + For the A tensor (mxfp4 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*) | TT | NT | NN | + |--------|---------------|--------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | Y | Y | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | Y | Y | Y | + | 2SM | 256x192x128 | Y | Y | Y | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe5m2n_bf16_bf16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
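+      // The test names in this file appear to encode <scaleA x dataA><layoutA>_<scaleB x dataB><layoutB>_<C>_<D><layoutD>,
+      // followed by <MmaTileShape>_<ClusterShape>_<1sm|2sm>_<schedule>. A short sketch of how the 2SM cases
+      // below split the MMA tile across a CTA pair (values taken from the 256x128x128 test in this file):
+      //   using MmaTileShape_MNK   = Shape<_256,_128,_128>;  // tile of the whole 2SM MMA instruction
+      //   using PerSmTileShape_MNK = Shape<_128,_128,_128>;  // per-CTA output tile handed to the epilogue (M halved)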
+      cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape
+      ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue
+      ElementC, GmemLayoutC, AlignC, // C tensor description
+      ElementD, GmemLayoutD, AlignD, // D tensor description
+      cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
+    >::CollectiveOp;
+
+  //
+  // Construct CollectiveMainloop
+  //
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec
+      ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement
+      ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement
+      ElementAccumulator, // Mma instruction accumulator type
+      MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape
+      // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
+    >::CollectiveOp;
+
+  // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  // Run tests
+  auto pass = test::gemm::device::TestAll<Gemm>();
+  // Check results
+  EXPECT_TRUE(pass);
+}
+
+
+TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe4m3n_bf16_bf16t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) {
+  // Describe A and B tensors
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignA = 128;
+  using GmemLayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  constexpr int AlignB = 16;
+  using GmemLayoutB = cutlass::layout::ColumnMajor;
+
+  // Describe C and D tensors
+  using ElementC = cutlass::bfloat16_t;
+  constexpr int AlignC = 8;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::bfloat16_t;
+  constexpr int AlignD = 8;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+
+  // Tile and cluster shapes
+  // Collective MMA takes tile shape of the MMA operation as input
+  using MmaTileShape_MNK = Shape<_128,_192,_128>;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_2,_2,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_192,_128>;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape
+      cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape.
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe5m2n_bf16_bf16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe4m3n_bf16_bf16t_bstensorop_f32, 256x128x128_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe4m3n_bf16_bf16t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe5m2n_bf16_bf16t_bstensorop_f32, 256x256x128_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float4_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu new file mode 100644 index 0000000000..2823d71d04 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp6xmxfp4 Block Scaled Gemm + + * A tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * B tensor: + * Types: {e2m1}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * Mma Tile Shapes supported depends on the layout for mxfp4 and mxfp6 mixed precision GEMM + The tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*)| NN | + |--------|---------------|----|----|-------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3n_ue8m0xe2m1t_f16_f16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile 
shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe2m1t_f16_f16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
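+      // For the mixed mxfp6 x mxfp4 tests in this file, the elided element-type arguments are presumably
+      // (per the e2m3/e3m2 and e2m1 encodings in the test names):
+      //   using ElementA = cutlass::mx_float6_t<cutlass::float_e2m3_t>;  // e3m2 variant: cutlass::float_e3m2_t
+      //   using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+      // Both operands use 128-element alignment along their stride-1 mode here, consistent with the
+      // layout-dependent support matrix in the file header above.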
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe2m1t_f16_f16t_bstensorop_f32, 256x256x128_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu new file mode 100644 index 0000000000..8add2c26d1 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp6xmxfp4 Block Scaled Gemm + + * A tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * B tensor: + * Types: {e2m1}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * Mma Tile Shapes supported depends on the layout for mxfp4 and mxfp6 mixed precision GEMM + The tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*)| TT | NT | NN | + |--------|---------------|-------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe2m1n_f16_f16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile 
shape, and cluster shape
+      cutlass::epilogue::collective::EpilogueTileAuto,                   // Epilogue subtile shape. Auto will find a suitable tile shape
+      ElementAccumulator, ElementCompute,                                // Mma instr's accumulator type and compute precision for epilogue
+      ElementC, GmemLayoutC, AlignC,                                     // C tensor description
+      ElementD, GmemLayoutD, AlignD,                                     // D tensor description
+      cutlass::epilogue::collective::EpilogueScheduleAuto                // Epilogue schedule policy
+    >::CollectiveOp;
+
+  //
+  // Construct CollectiveMainloop
+  //
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,   // Arch and Tensorop spec
+      ElementA, GmemLayoutA, AlignA,                                     // A tensor elem type, layout and alignment requirement
+      ElementB, GmemLayoutB, AlignB,                                     // B tensor elem type, layout and alignment requirement
+      ElementAccumulator,                                                // Mma instruction accumulator type
+      MmaTileShape_MNK, ClusterShape_MNK,                                // Mma instruction tile shape, cluster shape
+      // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100
+    >::CollectiveOp;
+
+  // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+    >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  // Run tests
+  auto pass = test::gemm::device::TestAll<Gemm>();
+  // Check results
+  EXPECT_TRUE(pass);
+}
+
+
+TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe2m1n_f16_f16t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) {
+  // Describe A and B tensors
+  using ElementA = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  constexpr int AlignA = 128;
+  using GmemLayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignB = 128;
+  using GmemLayoutB = cutlass::layout::ColumnMajor;
+
+  // Describe C and D tensors
+  using ElementC = cutlass::half_t;
+  constexpr int AlignC = 8;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::half_t;
+  constexpr int AlignD = 8;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+
+  // Tile and cluster shapes
+  // Collective MMA takes tile shape of the MMA operation as input
+  using MmaTileShape_MNK = Shape<_128,_192,_128>;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_2,_2,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_192,_128>;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,   // Arch and Tensorop spec
+      PerSmTileShape_MNK, ClusterShape_MNK,                              // Epilogue tile shape, and cluster shape
+      cutlass::epilogue::collective::EpilogueTileAuto,                   // Epilogue subtile shape.
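+      // StageCountAutoCarveout above asks the builder to derive the mainloop stage count from the
+      // shared memory left over once the epilogue's SharedStorage has been reserved. A rough model
+      // (capacities and per-stage footprints below are illustrative assumptions, not exact values):
+      //
+      //   smem_for_mainloop = smem_capacity_per_cta - sizeof(typename CollectiveEpilogue::SharedStorage);
+      //   bytes_per_stage   = smem_A_tile + smem_B_tile + smem_scale_factors + mbarrier_overhead;
+      //   stages            = smem_for_mainloop / bytes_per_stage;   // floor division
+      //
+      // Larger epilogue tiles therefore shrink the number of mainloop stages the builder can allocate,
+      // which is why the N=256 tests in these files keep C/D at f16 (or void C) instead of f32.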
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe2m1n_f16_f16t_bstensorop_f32, 128x256x128_1x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe3m1n_f16_f16t_bstensorop_f32, 256x128x128_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe3m1n_f16_f16t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
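+      // For the 2SM ("...2Sm...") kernels in this file, one MMA is issued cooperatively by a pair of
+      // CTAs, so the per-CTA output tile handed to the epilogue is the MMA tile split along M. With the
+      // shapes of this test (a restatement of the aliases above, not new API):
+      //
+      //   MmaTileShape_MNK   = 256 x 192 x 128   // owned by the 2-CTA (2SM) MMA
+      //   PerSmTileShape_MNK = 128 x 192 x 128   // = (256 / 2) x 192 x 128 per CTA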
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe2m1n_f16_f16t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu new file mode 100644 index 0000000000..b5f0722b51 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp6xmxfp6 Block Scaled Gemm + + * A tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * B tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * Mma Tile Shapes supported depends on the layout for mxfp4 and mxfp6 mixed precision GEMM + The tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*)| NN | + |--------|---------------|----|----|-------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3n_ue8m0xe2m3t_void_bf16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile 
shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe2m3t_void_bf16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe2m3t_void_bf16t_bstensorop_f32, 256x256x128_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu new file mode 100644 index 0000000000..115353b82a --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp6xmxfp6 Block Scaled Gemm + + * A tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * B tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * Mma Tile Shapes supported depends on the layout for mxfp6 mixed precision GEMM + The tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*)| TT | NT | NN | + |--------|---------------|-------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe2m3n_void_bf16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and 
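+      // The support matrix in this file's header reduces to one rule for the mxfp6 operands: the tile
+      // extent along the stride-1 (contiguous) dimension must be a multiple of 128 elements, which is
+      // also why AlignA and AlignB are 128 above. For this TN layout the contiguous dimension of both
+      // the A and B tiles is K, so a compile-time sanity check could look like (illustrative only,
+      // not part of the original test):
+      //
+      //   static_assert(cute::size<2>(MmaTileShape_MNK{}) % 128 == 0,
+      //                 "stride-1 tile extent must be a multiple of 128 elements");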
cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe3m2n_void_bf16t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe2m3n_void_bf16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe3m2n_void_bf16t_bstensorop_f32, 256x128x128_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe3m2n_void_bf16t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe2m3n_void_bf16t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu new file mode 100644 index 0000000000..73bf37ce38 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu @@ -0,0 +1,523 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp6xmxfp8 Block Scaled Gemm + + * A tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Row Major (T) + * Alignment: 16 elements + * Mma Tile Shapes supported: + For the A tensor (mxfp6 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*) | NN | + |--------|---------------|----|----|--------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | Y | Y | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | Y | Y | Y | + | 2SM | 256x192x128 | Y | Y | Y | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3n_ue8m0xe5m2t_void_f32t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3n_ue8m0xe4m3t_void_f32t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
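The comments above distinguish the MMA-instruction tile handed to the mainloop builder from the per-SM output tile handed to the epilogue builder. A minimal sketch, assuming the 2SM schedules split the MMA tile's M extent across the two participating CTAs (which matches every 1SM/2SM pairing used in these tests):

// Sketch: assumed relationship between MmaTileShape_MNK and PerSmTileShape_MNK.
// 1SM kernels: the two tiles coincide. 2SM kernels: each SM owns half of M.
// (cute/tensor.hpp, already included by this file, provides Shape, _128, _192, _256, size<>.)
using MmaTileShape_2Sm   = Shape<_256, _192, _128>;
using PerSmTileShape_2Sm = Shape<_128, _192, _128>;

static_assert(2 * int(size<0>(PerSmTileShape_2Sm{})) == int(size<0>(MmaTileShape_2Sm{})),
              "2SM: the per-SM output tile covers half of the MMA tile's M extent");
static_assert(int(size<1>(PerSmTileShape_2Sm{})) == int(size<1>(MmaTileShape_2Sm{})),
              "N extent is unchanged");
static_assert(int(size<2>(PerSmTileShape_2Sm{})) == int(size<2>(MmaTileShape_2Sm{})),
              "K extent is unchanged");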
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe5m2t_void_f32t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe4m3t_void_f32t_bstensorop_f32, 256x128x128_4x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3n_ue8m0xe4m3t_void_f32t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
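Every 2SM configuration in these files uses a cluster whose M extent is even (4x1x1 here, and 2x2x1 / 4x2x1 / 4x4x1 elsewhere), consistent with the 2SM MMA pairing neighboring CTAs along cluster M. The guard below is a sketch of that assumption, not a check taken from the library:

// Sketch (assumption): 2SM ("...2Sm...") kernel schedules pair CTAs along cluster M,
// so the cluster's M extent is kept even for those kernels in these tests.
using ClusterShape_2Sm_Example = Shape<_4, _1, _1>;
static_assert(int(size<0>(ClusterShape_2Sm_Example{})) % 2 == 0,
              "2SM kernel schedules expect an even cluster M extent");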
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2n_ue8m0xe5m2t_void_f32t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu new file mode 100644 index 0000000000..9fb1afdd87 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp6xmxfp8 Block Scaled Gemm + + * A tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Column Major (N) + * Alignment: 16 elements + * Mma Tile Shapes supported: + For the A tensor (mxfp6 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*) | TT | NT | NN | + |--------|---------------|--------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | Y | Y | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | Y | Y | Y | + | 2SM | 256x192x128 | Y | Y | Y | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe5m2n_void_f32t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe4m3n_void_f32t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
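The file header above fixes the alignments these builders require: 128 elements on the contiguous extent of the mxfp6 A operand and 16 elements for the mxfp8 B operand. A small runtime pre-check mirroring that rule for this file's TN layouts (the helper name is illustrative, not part of the patch):

// Sketch: illustrative pre-check for this file's TN problem shapes.
// A is mxfp6, row-major    -> its stride-1 extent is K and must be a multiple of 128.
// B is mxfp8, column-major -> its stride-1 extent is K and must be a multiple of 16.
inline bool mxf6_mxf8_tn_extents_are_aligned(int M, int N, int K) {
  (void)M; (void)N;   // only the contiguous (K) extent is constrained here
  return (K % 128 == 0) && (K % 16 == 0);
}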
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe5m2n_void_f32t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe4m3n_void_f32t_bstensorop_f32, 256x128x128_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe3m2t_ue8m0xe4m3n_void_f32t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
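The mainloop builders in these tests differ only in their final argument: some pin an explicit SM100 schedule (`KernelTmaWarpSpecialized1SmMxf8f6f4Sm100`, `KernelTmaWarpSpecialized2SmBlockScaledSm100`, ...) while others pass `KernelScheduleAuto` and let the builder choose. A sketch of the builder with its template brackets restored, written as a fragment of the surrounding test (a reconstruction, so the exact upstream spelling may differ):

// Reconstruction sketch: mainloop builder with template arguments restored.
// Swapping the last argument for cutlass::gemm::collective::KernelScheduleAuto
// yields the "auto"-scheduled variants used elsewhere in this file.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
    ElementA, GmemLayoutA, AlignA,
    ElementB, GmemLayoutB, AlignB,
    ElementAccumulator,
    MmaTileShape_MNK, ClusterShape_MNK,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100   // or another SM100 schedule tag
  >::CollectiveOp;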
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe2m3t_ue8m0xe5m2n_void_f32t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float6_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu new file mode 100644 index 0000000000..fab68dc5ea --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp8xmxfp4 Block Scaled Gemm + + * A tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Column Major (N) + * Alignment: 16 elements + * B tensor: + * Types: {e2m1}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * Mma Tile Shapes supported: + For the B tensor (mxfp4 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*) | NN | + |--------|---------------|----|----|--------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3n_ue8m0xe2m1t_f16_bf16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe2m1t_f16_bf16t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
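The comment above explains that a 256-wide N tile with 32-bit C/D would leave too little shared memory for the epilogue; since `StageCountAutoCarveout` subtracts exactly `sizeof(typename CollectiveEpilogue::SharedStorage)` from the mainloop's budget, printing that footprint is a quick way to see the trade-off (a debugging sketch, not part of the test):

// Sketch: report the epilogue's SMEM footprint, i.e. the byte count that
// StageCountAutoCarveout removes from the mainloop's stage budget.
#include <cstdio>

template <class CollectiveEpilogue>
void print_epilogue_smem_carveout() {
  std::printf("epilogue SMEM carveout: %zu bytes\n",
              sizeof(typename CollectiveEpilogue::SharedStorage));
}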
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe2m1t_f16_bf16t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu new file mode 100644 index 0000000000..f733d47f85 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu @@ -0,0 +1,523 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp8xmxfp4 Block Scaled Gemm + + * A tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Row Major (T) + * Alignment: 16 elements + * B tensor: + * Types: {e2m1}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * Mma Tile Shapes supported: + For the B tensor (mxfp4 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*) | TT | NT | NN | + |--------|---------------|--------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe2m1n_f16_bf16t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
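The 128-element figure in the support notes above is a requirement on the contiguous (stride-1) dimension of the narrow mxfp4/mxfp6 operand, while the mxfp8 operand only needs 16-element alignment (AlignA = 16 vs AlignB = 128 in these tests). A small host-side guard of the kind a caller might run before picking one of these kernels, written generically in plain C++ so it does not assert which matrix dimension is contiguous for a given layout:

// Plain C++ sketch: pre-launch alignment guard mirroring the AlignA/AlignB values
// used in these tests (16 elements for the mxfp8 operand, 128 for the mxfp4/mxfp6 one).
#include <cstdio>

// True when the extent of a tensor's contiguous (stride-1) mode meets the
// alignment, in elements, that the chosen kernel was built with.
constexpr bool meets_alignment(long contiguous_extent, int required_elements) {
  return required_elements > 0 && contiguous_extent % required_elements == 0;
}

int main() {
  constexpr int kAlignA = 16;   // mxfp8 operand
  constexpr int kAlignB = 128;  // mxfp4 / mxfp6 operand
  // Example: K = 512 as the contiguous extent of both operands (illustrative value).
  std::printf("K=512 -> A ok: %d, B ok: %d\n",
              int(meets_alignment(512, kAlignA)), int(meets_alignment(512, kAlignB)));
  // K = 192 passes the 16-element check but fails the 128-element one.
  std::printf("K=192 -> A ok: %d, B ok: %d\n",
              int(meets_alignment(192, kAlignA)), int(meets_alignment(192, kAlignB)));
  return 0;
}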
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe2m1n_f16_bf16t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe2m1n_f16_bf16t_bstensorop_f32, 128x256x128_4x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
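StageCountAutoCarveout asks the mainloop builder to pick the pipeline stage count itself after reserving the epilogue's shared-memory footprint, which is why every mainloop above passes sizeof(CollectiveEpilogue::SharedStorage) to it. Conceptually the carveout is just a budget division; the plain C++ sketch below illustrates that idea with made-up byte counts (the real SMEM capacity and per-stage footprint come from the architecture and the builder, not from these numbers).

// Plain C++ sketch of the idea behind StageCountAutoCarveout<sizeof(EpilogueSharedStorage)>.
// All byte counts below are illustrative placeholders only.
#include <cstdio>

constexpr int stage_count(long smem_capacity_bytes,      // usable SMEM per CTA (architecture dependent)
                          long epilogue_bytes,           // sizeof(CollectiveEpilogue::SharedStorage)
                          long mainloop_bytes_per_stage) // A/B tiles, scale factors, barriers per stage
{
  long budget = smem_capacity_bytes - epilogue_bytes;    // carve the epilogue out first
  return budget > 0 ? static_cast<int>(budget / mainloop_bytes_per_stage) : 0;
}

int main() {
  // Hypothetical numbers: 224 KiB capacity, 32 KiB epilogue, 48 KiB per mainloop stage -> 4 stages.
  std::printf("stages = %d\n", stage_count(224L << 10, 32L << 10, 48L << 10));
  return 0;
}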
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe2m1n_f16_bf16t_bstensorop_f32, 256x128x128_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe2m1n_f16_bf16t_bstensorop_f32, 256x192x128_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
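These tests deliberately alternate between letting the builder choose (KernelScheduleAuto) and pinning one of the explicit SM100 block-scaled schedules, which is how a caller selects the 1-SM versus 2-SM MMA path and the Mxf8f6f4 versus generic block-scaled instruction flavour. A compact reference of the tags exercised in these files, written as alias declarations one would pass as the last argument of the mainloop CollectiveBuilder; the tag names are the ones appearing in this patch, and their availability via the include below is assumed from the CUTLASS 3.8 headers.

// Schedule tags used across these unit tests; pick exactly one per mainloop builder.
// Sketch only: assumes the CUTLASS 3.8 headers already included by these test files.
#include "cutlass/gemm/collective/collective_builder.hpp"

namespace schedule_sketch {

// Let the builder pick a schedule compatible with the type/tile/cluster combination.
using ScheduleAuto  = cutlass::gemm::collective::KernelScheduleAuto;

// Explicit 1-SM / 2-SM schedules for the mixed f8/f6/f4 ("Mxf8f6f4") instruction path.
using Schedule1SmMx = cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100;
using Schedule2SmMx = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;

// Explicit 1-SM / 2-SM schedules for the generic block-scaled tensor-op path.
// In these tests the 2-SM tags always pair with 256-row MMA tiles.
using Schedule1SmBs = cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100;
using Schedule2SmBs = cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100;

} // namespace schedule_sketch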
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe2m1n_f16_bf16t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float4_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu new file mode 100644 index 0000000000..34468a6082 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp8xmxfp6 Block Scaled Gemm + + * A tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Column Major (N) + * Alignment: 16 elements + * B tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Row Major (T) + * Alignment: 128 elements + * Mma Tile Shapes supported: + For the B tensor (mxfp6 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*) | NN | + |--------|---------------|----|----|--------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3n_ue8m0xe2m3t_f16_f8t_bstensorop_f32, 128x128x128_1x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape 
+ cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe2m3t_f16_f8t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
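The "N=256 using f32 and f16 consumes too much SMEM" remark above is a shared-memory observation: the TMA warp-specialized epilogue broadly stages C and D tiles through SMEM, so doubling N or widening the C/D element types multiplies that footprint and shrinks what the mainloop carveout has left for pipeline stages. A back-of-the-envelope comparison in plain C++, with tile extents taken from these tests; the assumption that one full per-SM tile of C plus D is staged at once is purely illustrative, not a statement about the exact CUTLASS epilogue subtiling.

// Rough SMEM footprint comparison for staging one 128xN tile of C plus D.
// Illustrative only: the real epilogue subtiles and staging are chosen by the builder.
#include <cstdio>

constexpr long tile_bytes(int m, int n, int bytes_c, int bytes_d) {
  return static_cast<long>(m) * n * (bytes_c + bytes_d);
}

int main() {
  // N = 128 with f16 C and bf16 D (2 + 2 bytes) vs N = 256 with f32 C and f16 D (4 + 2 bytes).
  std::printf("128x128, f16/bf16 : %ld KiB\n", tile_bytes(128, 128, 2, 2) >> 10); // 64 KiB
  std::printf("128x256, f32/f16  : %ld KiB\n", tile_bytes(128, 256, 4, 2) >> 10); // 192 KiB
  // Narrowing C/D to 16-bit (or 8-bit) types keeps the wide-N footprint manageable,
  // which is why the N = 256 tests in these files use half/bfloat16 or e4m3 C and D.
  std::printf("128x256, f16/bf16 : %ld KiB\n", tile_bytes(128, 256, 2, 2) >> 10); // 128 KiB
  return 0;
}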
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe2m3t_f16_f8t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu new file mode 100644 index 0000000000..33e36aa436 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp8xmxfp6 Block Scaled Gemm + + * A tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Row Major (T) + * Alignment: 16 elements + * B tensor: + * Types: {e2m3,e3m2}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * Mma Tile Shapes supported: + For the B tensor (mxfp6 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*) | TT | NT | NN | + |--------|---------------|--------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | N | N | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | N | N | Y | + | 2SM | 256x192x128 | Y | N | N | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe2m3n_f16_f8t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape 
+ cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe3m2n_f16_f8t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
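Across these files the C/D alignments track the element width: 8 elements for the 16-bit half/bfloat16 tensors and 16 elements for the 8-bit e4m3/e5m2 tensors, which is consistent with a fixed 16-byte access granularity; the 128-element figure on the sub-byte mx operands is the separate requirement documented in each file header. The one-liner below reproduces the 16-byte-derived values used here; the granularity constant is an inference from the numbers in these tests, not a quoted CUTLASS rule.

// Derive alignment-in-elements from a 16-byte access granularity (inferred from the
// AlignA/AlignC/AlignD values in these tests; not a normative CUTLASS constant).
#include <cstdio>

constexpr int align_elems(int element_bits) { return 16 * 8 / element_bits; }

int main() {
  std::printf("half_t / bfloat16_t (16-bit): %d elements\n", align_elems(16)); // 8,  matches AlignC/AlignD
  std::printf("e4m3 / e5m2 (8-bit)         : %d elements\n", align_elems(8));  // 16, matches AlignA/AlignD
  std::printf("float (32-bit)              : %d elements\n", align_elems(32)); // 4
  return 0;
}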
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe2m3n_f16_f8t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe3m2n_f16_f8t_bstensorop_f32, 256x128x128_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe3m2n_f16_f8t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe2m3n_f16_f8t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float6_t; + constexpr int AlignB = 128; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu new file mode 100644 index 0000000000..965de2c2c2 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu @@ -0,0 +1,523 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
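Editorial orientation for the mxf8_mxf8 tests introduced in this new file: the mx_float8_t operands pair an FP8 element (e4m3 or e5m2) with a ue8m0 scale shared by a block of elements. The host-side sketch below is illustrative only and is not taken from the patch; it assumes the OCP MX convention (ue8m0 is a pure power-of-two exponent with bias 127 shared by a 32-element block) and ignores FP8 NaN encodings. The helper names are hypothetical.

// Illustrative sketch only (not part of the patch): decode one MX-scaled FP8 value.
// Assumes the OCP MX convention: a ue8m0 scale byte is a power-of-two exponent with
// bias 127 shared by a 32-element block; the stored element is an ordinary FP8 value.
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: decode an FP8 e4m3 byte (1 sign, 4 exponent, 3 mantissa, bias 7).
static float decode_e4m3(uint8_t v) {
  int sign = (v >> 7) & 0x1;
  int exp  = (v >> 3) & 0xF;
  int man  =  v       & 0x7;
  float mag = (exp == 0)
    ? std::ldexp(man / 8.0f, 1 - 7)          // subnormal: no implicit leading 1
    : std::ldexp(1.0f + man / 8.0f, exp - 7); // normal: implicit leading 1, bias 7
  return sign ? -mag : mag;
}

// Hypothetical helper: apply the shared ue8m0 block scale, i.e. multiply by 2^(scale - 127).
static float decode_mx_element(uint8_t fp8_byte, uint8_t ue8m0_scale) {
  return std::ldexp(decode_e4m3(fp8_byte), int(ue8m0_scale) - 127);
}

int main() {
  // 0x38 encodes 1.0 in e4m3; a scale byte of 130 contributes 2^3 = 8.
  std::printf("%f\n", decode_mx_element(0x38, 130));  // prints 8.000000
  return 0;
}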
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp8xmxfp8 Block Scaled Gemm + + * A tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Column Major (N) + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Row Major (T) + * Alignment: 16 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN | TT | NT (*) | NN | + |--------|---------------|----|----|--------|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | Y | Y | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | Y | Y | Y | + | 2SM | 256x192x128 | Y | Y | Y | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe5m2t_void_f8t_bstensorop_f32, 128x128x128_1x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe4m3t_void_f8t_bstensorop_f32, 128x192x128_1x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3n_ue8m0xe5m2t_void_f8t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3n_ue8m0xe4m3t_void_f8t_bstensorop_f32, 256x128x128_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2n_ue8m0xe4m3t_void_f8t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3n_ue8m0xe5m2t_void_f8t_bstensorop_f32, 256x256x128_4x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu new file mode 100644 index 0000000000..91a9e1004a --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
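A pattern that repeats through these tests, and through the TN-layout file whose header begins here, is that the mainloop builder receives the full MMA tile (MmaTileShape_MNK) while the epilogue builder receives the per-SM output tile (PerSmTileShape_MNK): with the 2SM schedules one MMA spans a CTA pair, so each CTA owns half of the M extent (256x192x128 becomes 128x192x128, and so on), while 1SM schedules leave the tile unchanged. The fragment below only restates that relationship with plain integers; the names are illustrative and the M-split interpretation is an assumption drawn from the tile shapes used in these tests.

// Illustrative sketch only: relation between the MMA tile given to the mainloop builder
// and the per-SM output tile given to the epilogue builder in the surrounding tests.
// Assumption: a "2SM" kernel pairs two CTAs on one MMA and splits the M extent in half.
#include <cstdio>

struct TileMNK { int m, n, k; };

// Hypothetical helper: derive the per-SM (per-CTA) output tile from the MMA tile.
constexpr TileMNK per_sm_tile(TileMNK mma_tile, bool is_2sm) {
  return { is_2sm ? mma_tile.m / 2 : mma_tile.m, mma_tile.n, mma_tile.k };
}

int main() {
  // Mirrors e.g. MmaTileShape_MNK = 256x192x128 with a 2SM schedule:
  TileMNK cta = per_sm_tile({256, 192, 128}, /*is_2sm=*/true);   // -> 128x192x128
  std::printf("per-SM tile: %dx%dx%d\n", cta.m, cta.n, cta.k);

  // A 1SM schedule leaves the tile unchanged (128x128x128 stays 128x128x128).
  TileMNK one_sm = per_sm_tile({128, 128, 128}, /*is_2sm=*/false);
  std::printf("per-SM tile: %dx%dx%d\n", one_sm.m, one_sm.n, one_sm.k);
  return 0;
}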
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for mxfp8xmxfp8 Block Scaled Gemm + + * A tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Row Major (T) + * Alignment: 16 elements + * B tensor: + * Types: {e5m2,e4m3}xue8m0 + * Layout: Column Major (N) + * Alignment: 16 elements + * Mma Tile Shapes supported: + For the A tensor (mxfp6 type) the tile dimension with stride-1 should be divisible by 128, i.e., 128 element aligned. + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*) | TT | NT | NN | + |--------|---------------|--------|----|----|----| + | 1SM | 128x128x128 | Y | Y | Y | Y | + | 1SM | 128x192x128 | Y | Y | Y | Y | + | 1SM | 128x256x128 | Y | Y | Y | Y | + | 2SM | 256x128x128 | Y | Y | Y | Y | + | 2SM | 256x192x128 | Y | Y | Y | Y | + | 2SM | 256x256x128 | Y | Y | Y | Y | + + (*) Unit tests in this file +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe5m2n_void_f8t_bstensorop_f32, 128x128x128_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + 
cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe4m3n_void_f8t_bstensorop_f32, 128x192x128_2x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe5m2n_void_f8t_bstensorop_f32, 128x256x128_4x2x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + // For N=256 using f32 and f16 consumes too much SMEM space for Epilogue. + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe4m3n_void_f8t_bstensorop_f32, 256x128x128_4x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe5m2t_ue8m0xe4m3n_void_f8t_bstensorop_f32, 256x192x128_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_ue8m0xe4m3t_ue8m0xe5m2n_void_f8t_bstensorop_f32, 256x256x128_2x1x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::mx_float8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::mx_float8_t; + constexpr int AlignB = 16; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 16; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e5m2_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape
+      ElementAccumulator, ElementCompute,                              // Mma instr's accumulator type and compute precision for epilogue
+      ElementC, GmemLayoutC, AlignC,                                   // C tensor description
+      ElementD, GmemLayoutD, AlignD,                                   // D tensor description
+      cutlass::epilogue::collective::EpilogueScheduleAuto              // Epilogue schedule policy
+    >::CollectiveOp;
+
+  //
+  // Construct CollectiveMainloop
+  //
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec
+      ElementA, GmemLayoutA, AlignA,                                   // A tensor elem type, layout and alignment requirement
+      ElementB, GmemLayoutB, AlignB,                                   // B tensor elem type, layout and alignment requirement
+      ElementAccumulator,                                              // Mma instruction accumulator type
+      MmaTileShape_MNK, ClusterShape_MNK,                              // Mma instruction tile shape, cluster shape
+      // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto                    // Kernel schedule policy. Auto or using targeted scheduling policy
+    >::CollectiveOp;
+
+  // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  // Run tests
+  auto pass = test::gemm::device::TestAll<Gemm>();
+  // Check results
+  EXPECT_TRUE(pass);
+}
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu
new file mode 100644
index 0000000000..3501f6ec14
--- /dev/null
+++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu
@@ -0,0 +1,683 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
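Orientation for the nvf4_nvf4_bf16_bf16.cu tests whose header begins here: the nv_float4_t operands are 4-bit e2m1 values that share a ue4m3 scale over a small block (16 elements in the NVFP4 convention assumed by this note). The sketch below is illustrative only and is not taken from the patch; the e2m1 magnitude table and helper names are assumptions.

// Illustrative sketch only: decode one NVFP4 element (e2m1 data x ue4m3 block scale).
// Assumptions: e2m1 is 1 sign / 2 exponent / 1 mantissa with magnitudes
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}; ue4m3 is an unsigned e4m3 value (bias 7) shared
// by a 16-element block.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float decode_e2m1(uint8_t nibble) {
  static const float kMag[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float mag = kMag[nibble & 0x7];
  return (nibble & 0x8) ? -mag : mag;
}

static float decode_ue4m3(uint8_t v) {  // unsigned e4m3: 4 exponent bits, 3 mantissa bits, bias 7
  int exp = (v >> 3) & 0xF;
  int man =  v       & 0x7;
  return (exp == 0) ? std::ldexp(man / 8.0f, 1 - 7)
                    : std::ldexp(1.0f + man / 8.0f, exp - 7);
}

static float decode_nvfp4(uint8_t e2m1_nibble, uint8_t ue4m3_scale) {
  return decode_e2m1(e2m1_nibble) * decode_ue4m3(ue4m3_scale);
}

int main() {
  // Nibble 0x5 decodes to 3.0; scale byte 0x40 decodes to 2.0; product is 6.0.
  std::printf("%f\n", decode_nvfp4(0x5, 0x40));
  return 0;
}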
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for nvfp4xnvfp4 Block Scaled Gemm + + * A tensor: + * Types: {e2m1}xue4m3 + * Layout: Row Major (T) + * Alignment: 32 elements + * B tensor: + * Types: {e2m1}xue4m3 + * Layout: Column Major (N) + * Alignment: 32 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Size | TN (*) | TT | NT | NN | + |--------|---------------|--------|----|----|----| + | 1SM | 128x128x256 | Y | N | N | N | + | 1SM | 128x192x256 | Y | N | N | N | + | 1SM | 128x256x256 | Y | N | N | N | + | 2SM | 256x128x256 | Y | N | N | N | + | 2SM | 256x192x256 | Y | N | N | N | + | 2SM | 256x256x256 | Y | N | N | N | + + (*) Unit tests in this file +*/ +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 128x128x256_4x4x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 256x128x256_2x2x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +/////////////////////////////////////////////////////////////////////////////// +// +// Using targeted scheduling with **static** cluster shapes +// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 128x128x256_2x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 256x128x256_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +////////////////////////////////////////////////////////////////////////////// +// +// Using large Cta Tiles: N=192 and N=256 +// +////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 128x192x256_2x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 128x256x256_2x1x1_1sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 256x192x256_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 256x256x256_2x4x1_2sm_auto) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16_features.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16_features.cu new file mode 100644 index 0000000000..7efb49d379 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16_features.cu @@ -0,0 +1,374 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Runtime data type for blockscaled gemm fp4 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////// +// +// Using Runtime Types +// +////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 128x128x256_4x2x1_1sm_auto_runtime_dtypes) { + // Describe A and B tensors + using ElementA = cutlass::type_erased_dynamic_nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::type_erased_dynamic_nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestRuntimeDataTypeSmall(cute::UMMA::MXF4Format::E2M1, cute::UMMA::MXF4Format::E2M1); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 256x128x256_2x4x1_2sm_auto_runtime_dtypes) { + // Describe A and B tensors + using ElementA = cutlass::type_erased_dynamic_nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::type_erased_dynamic_nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestRuntimeDataTypeSmall(cute::UMMA::MXF4Format::E2M1, cute::UMMA::MXF4Format::E2M1); + // Check results + EXPECT_TRUE(pass); +} + +////////////////////////////////////////////////////////////////////////////// +// +// Using Stream-K Scheduler +// +////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 128x128x256_1x4x1_1sm_auto_streamK) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler // Specify the streamK scheduler for the kernel + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_NVue4m3xe2m1t_NVue4m3xe2m1n_bf16t_bf16t_bstensorop_f32, 256x128x256_2x2x1_2sm_auto_streamK) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::bfloat16_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. 
Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler // Specify the streamK scheduler for the kernel + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_f16_nvfp4_epilogue.cu b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_f16_nvfp4_epilogue.cu new file mode 100644 index 0000000000..e4b2513500 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_f16_nvfp4_epilogue.cu @@ -0,0 +1,436 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit test for nvfp4 Block Scaled Gemm with nvfp4 output + D tensor: + * Types: e2m1x{ue4m3} + * Layout: Column Major (T) + * Alignment: 32 + * Scale factors need to be generated with the fp4 output. It is generated along the continuous dimensions of the D tensor. + * Meanwhile, before scale factor generation, it could have other epilogue fusion operation. + * alpha + * beta + * activation + * bias + This UT tests + - alpha + beta + scale-factor generation + - alpha + beta + bias + scale-factor generation +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////// +// FusionOperation: k-major output and datatype is float_e2m1_t with float_ue4m3_t scale-factor (vecsize 16) +// with alpha/beta fusion +////////////////////////////////////////////////////////////////////////////// +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_ue8m0xe2m1t_outputVs16_bstensorop_1sm_f32, 128x128x256_4x4x1) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e2m1_t; + constexpr int AlignD = 32; + using GmemLayoutD = cutlass::layout::RowMajor; + // Describe SFD tensor + using ElementSFD = cutlass::float_ue4m3_t; + using GmemLayoutSFD = GmemLayoutD; + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct FusionOperation + // + constexpr int SFDVectorSize = 16; + // 
Define the fusion operation applied during epilogue + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFDVectorSize, + ElementD, ElementCompute, + ElementSFD, GmemLayoutSFD, + ElementC + >; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_ue4m3xe2m1t_ue4m3xe2m1n_ue4m3xe2m1t_outputVs16_bstensorop_2sm_f32, 256x128x256_4x4x1) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e2m1_t; + constexpr int AlignD = 32; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Describe SFD tensor + using ElementSFD = cutlass::float_ue4m3_t; + using GmemLayoutSFD = GmemLayoutD; + + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct FusionOperation + // + constexpr int SFDVectorSize = 16; + // Define the 
fusion operation applied during epilogue + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFDVectorSize, + ElementD, ElementCompute, + ElementSFD, GmemLayoutSFD, + ElementC + >; + + // + // Construct CollectiveEpilogue + // + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100 + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +////////////////////////////////////////////////////////////////////////////// +// FusionOperation: k-major output and datatype is float_e2m1_t with float_ue4m3_t scale-factor (vecsize 32) +// with alpha/beta fusion +////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_ue8m0xe2m1t_outputVs32_bstensorop_1sm_f32, 128x128x256_4x4x1) { + // Describe A and B tensors + using ElementA = cutlass::nv_float4_t; + constexpr int AlignA = 32; + using GmemLayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::nv_float4_t; + constexpr int AlignB = 32; + using GmemLayoutB = cutlass::layout::ColumnMajor; + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e2m1_t; + constexpr int AlignD = 32; + using GmemLayoutD = cutlass::layout::RowMajor; + // Describe SFD tensor + using ElementSFD = cutlass::float_ue4m3_t; + using GmemLayoutSFD = GmemLayoutD; + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_256>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_4,_1>; + // Collective Epilogue 
takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_256>; + + // + // Construct FusionOperation + // + constexpr int SFDVectorSize = 32; + // Define the fusion operation applied during epilogue + using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + SFDVectorSize, + ElementD, ElementCompute, + ElementSFD, GmemLayoutSFD, + ElementC + >; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. 
Auto or using targeted scheduling policy
+    >::CollectiveOp;
+
+  // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+    >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  // Run tests
+  auto pass = test::gemm::device::TestAll<Gemm>();
+  // Check results
+  EXPECT_TRUE(pass);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// FusionOperation: k-major output and datatype is float_e2m1_t with float_ue4m3_t scale-factor (vecsize 16)
+// with alpha+beta+relu+bias fusion
+//////////////////////////////////////////////////////////////////////////////
+
+TEST(SM100Only_Device_Gemm_ue8m0xe2m1t_ue8m0xe2m1n_ue8m0xe2m1n_outputVs16_bstensorop_1sm_f32_bias_relu, 128x128x256_4x4x1) {
+  // Describe A and B tensors
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignA = 32;
+  using GmemLayoutA = cutlass::layout::RowMajor;
+
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  constexpr int AlignB = 32;
+  using GmemLayoutB = cutlass::layout::ColumnMajor;
+  // Describe C and D tensors
+  using ElementC = cutlass::half_t;
+  constexpr int AlignC = 8;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  using ElementD = cutlass::float_e2m1_t;
+  constexpr int AlignD = 32;
+  using GmemLayoutD = cutlass::layout::RowMajor;
+  // Describe SFD tensor
+  using ElementSFD = cutlass::float_ue4m3_t;
+  using GmemLayoutSFD = GmemLayoutD;
+  // Mma's accumulator type
+  using ElementAccumulator = float;
+  // Epilogue computation's precision type
+  using ElementCompute = float;
+  // Bias type
+  using ElementBias = float;
+
+  // Collective MMA takes tile shape of the MMA operation as input
+  using MmaTileShape_MNK = Shape<_128,_128,_256>;
+  // Cluster size for multicast
+  using ClusterShape_MNK = Shape<_4,_4,_1>;
+  // Collective Epilogue takes the output tile shape for 1 CTA
+  using PerSmTileShape_MNK = Shape<_128,_128,_256>;
+
+  // Scale-factor vector size along the continuous dimension of D (vecsize 16, per the section header above)
+  constexpr int SFDVectorSize = 16;
+
+  using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasBlockScaleFactor<
+      SFDVectorSize, ElementD, ElementCompute,
+      ElementSFD, GmemLayoutSFD,
+      ElementBias, ElementC
+    >;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
+      PerSmTileShape_MNK, ClusterShape_MNK,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementCompute,
+      ElementC, GmemLayoutC, AlignC,
+      ElementD, GmemLayoutD, AlignD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto,
+      FusionOperation
+    >::CollectiveOp;
+
+  //
+  // Construct CollectiveMainloop
+  //
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
+      ElementA, GmemLayoutA, AlignA,
+      ElementB, GmemLayoutB, AlignB,
+      ElementAccumulator,
+      MmaTileShape_MNK, ClusterShape_MNK,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+    >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  auto pass = test::gemm::device::TestAll<Gemm>();
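+  // TestAll() is assumed (based on the shared gemm_testbed_3x harness) to sweep a small set of
+  // problem sizes and alpha/beta values and compare the device result against a host reference,
+  // so 'pass' is false if any tested configuration mismatches.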
+  EXPECT_TRUE(pass);
+}
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
diff --git a/test/unit/gemm/device/sm100_gemm_bf16_bf16_bf16_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_bf16_bf16_bf16_tensor_op_f32_ptr_array.cu
new file mode 100644
index 0000000000..7928a697d0
--- /dev/null
+++ b/test/unit/gemm/device/sm100_gemm_bf16_bf16_bf16_tensor_op_f32_ptr_array.cu
@@ -0,0 +1,364 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; + +TEST(SM100_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_2sm_f32_ptr_array, 256x128x64_4x1x1) { +// A matrix configuration +using ElementA = cutlass::bfloat16_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::bfloat16_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_4,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::bfloat16_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::bfloat16_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + 
cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::bfloat16_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::bfloat16_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + 
cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
+    CollectiveMainloop,
+    CollectiveEpilogue
+>;
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestSmall<Gemm>(3.0, 2.0);
+  EXPECT_TRUE(result);
+}
+
+TEST(SM100_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) {
+// A matrix configuration
+using ElementA = cutlass::bfloat16_t;                                    // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                               // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+// B matrix configuration
+using ElementB = cutlass::bfloat16_t;                                    // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor;                            // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+// C matrix configuration
+using ElementC = cutlass::bfloat16_t;                                    // Element type for C matrix operands
+using LayoutC = cutlass::layout::ColumnMajor;                            // Layout type for C matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+// D matrix configuration
+using ElementD = cutlass::bfloat16_t;                                    // Element type for D matrix operands
+using LayoutD = cutlass::layout::ColumnMajor;                            // Layout type for D matrix operands
+constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
+// Core kernel configurations
+using ElementAccumulator = float;                                        // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm100;                                    // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp;                    // Operator class tag
+using TileShape_MNK = Shape<_256,_128,_64>;
+using ClusterShape_MNK = Shape<_2,_1,_1>;
+using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{}));
+using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{}));
+using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{}));
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;  // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;       // Epilogue to launch
+
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
+    OutputCtaShape, ClusterShape_MNK,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC, AlignmentC,
+    ElementD, LayoutD, AlignmentD,
+    EpilogueSchedule
+  >::CollectiveOp;
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA, AlignmentA,
+    ElementB, LayoutB, AlignmentB,
+    ElementAccumulator,
+    MmaTileShape, ClusterShape_MNK,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+    cutlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
+    CollectiveMainloop,
+    CollectiveEpilogue
+>;
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestSmall<Gemm>(1.0, 0.0);
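TestSmall<Gemm> above drives the kernel through the shared ptr-array testbed in gemm_testbed_3x_ptr_array.hpp. For orientation, a host-side launch of such an adapter type outside the testbed looks roughly like the sketch below; the argument-field order and the kArray mode follow the CUTLASS 3.x ptr-array examples, and every parameter name here is a caller-supplied placeholder rather than something defined in this patch.

#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"   // cutlass::device_memory::allocation

// Illustrative host-side launch for a ptr-array GemmUniversalAdapter type such as
// the `Gemm` alias built in this test. Per-batch pointer arrays and strides are
// assumed to already live in device memory.
template <class Gemm,
          class PtrA, class StrideA, class PtrB, class StrideB,
          class PtrC, class StrideC, class PtrD, class StrideD>
cutlass::Status run_ptr_array_gemm(
    int M, int N, int K, int batches,
    PtrA ptr_A, StrideA dA, PtrB ptr_B, StrideB dB,
    PtrC ptr_C, StrideC dC, PtrD ptr_D, StrideD dD,
    float alpha, float beta) {

  typename Gemm::Arguments args{
    cutlass::gemm::GemmUniversalMode::kArray,
    {{M, N, K, batches}},                    // ArrayProblemShape<Shape<int,int,int,int>>
    {ptr_A, dA, ptr_B, dB},                  // mainloop: per-batch pointer arrays + strides
    {{alpha, beta}, ptr_C, dC, ptr_D, dD}    // epilogue: fusion args + C/D pointer arrays
  };

  Gemm gemm;
  if (gemm.can_implement(args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }

  // The adapter reports how much scratch space the kernel wants and expects it
  // as a raw device pointer at initialization time.
  cutlass::device_memory::allocation<uint8_t> workspace(Gemm::get_workspace_size(args));

  cutlass::Status status = gemm.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return gemm.run();
}

Passing a null workspace is only safe when get_workspace_size reports zero; otherwise initialize rejects the arguments, which is why the sketch always allocates whatever size the adapter requests.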
+ EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_bf16t_bf16n_f32n_tensor_op_2sm_f32_ptr_array, 256x256x64_4x4x1) { +// A matrix configuration +using ElementA = cutlass::bfloat16_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::bfloat16_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_4,_4,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu new file mode 100644 index 
0000000000..5cd4158e2a --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu @@ -0,0 +1,323 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/// A Row B Col +TEST(SM100_Device_Gemm_f16t_f16n_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 16, + ElementD, GmemLayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, GmemLayoutA, 8, + ElementB, GmemLayoutB, 8, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Row +TEST(SM100_Device_Gemm_f16n_f16t_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + 
cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 16, + ElementD, GmemLayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, GmemLayoutA, 8, + ElementB, GmemLayoutB, 8, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Row B Row +TEST(SM100_Device_Gemm_f16t_f16t_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 16, + ElementD, GmemLayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, GmemLayoutA, 8, + ElementB, GmemLayoutB, 8, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Col +TEST(SM100_Device_Gemm_f16n_f16n_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue 
+ // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 16, + ElementD, GmemLayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, GmemLayoutA, 8, + ElementB, GmemLayoutB, 8, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_bf16t_bf16t_bf32_void_f32n_tensor_op, 128x256x64_1x2x1) { + using ElementA = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::bfloat16_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using MmaTileShape = Shape<_128,_128,_64>; + using TileShape_MNK = Shape<_128,_256,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 8, + float, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f16_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f16_ptr_array.cu new file mode 100644 index 0000000000..08202e16ab --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f16_ptr_array.cu @@ -0,0 +1,364 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; + +TEST(SM100_Device_Gemm_f16t_f16t_f16n_f16n_tensor_op_1sm_f16_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access 
granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16n_f16n_tensor_op_1sm_f16_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // 
Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16n_f16n_tensor_op_1sm_f16_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations 
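The shape aliases declared just below follow the algebra used by every test in this file: TileShape_MNK divided by ClusterShape_MNK gives OutputCtaShape, the per-CTA tile handed to the epilogue builder, while TileShape_MNK divided by AtomThrShape gives MmaTileShape, the tile handed to the mainloop builder; AtomThrShape is the cluster shape divided by 1x1x1 for the 1-SM schedules and by 2x1x1 for the 2-SM schedules. A minimal standalone sketch of that arithmetic in plain integers (values match this 128x64x64 / 1x2x1 test and the 2-SM 256x128x64 / 2x1x1 test earlier in the file; the interpretation in the comments is a reading of how the builders consume these shapes, not text from the patch):

#include <cassert>

// Element-wise division of M/N/K extents, mirroring what cute::shape_div does
// for the static shapes used in these tests.
struct Shape3 { int m, n, k; };

static Shape3 shape_div3(Shape3 a, Shape3 b) {
  return {a.m / b.m, a.n / b.n, a.k / b.k};
}

int main() {
  // 1-SM case: TileShape 128x64x64, Cluster 1x2x1, atom divisor 1x1x1.
  Shape3 tile_1sm{128, 64, 64}, cluster_1sm{1, 2, 1};
  Shape3 atom_thr_1sm = shape_div3(cluster_1sm, {1, 1, 1});   // = 1x2x1
  Shape3 out_cta_1sm  = shape_div3(tile_1sm, cluster_1sm);    // per-CTA (epilogue) tile
  Shape3 mma_tile_1sm = shape_div3(tile_1sm, atom_thr_1sm);   // tile given to the mainloop builder
  assert(out_cta_1sm.n == 32 && mma_tile_1sm.n == 32);        // both 128x32x64 here

  // 2-SM case: TileShape 256x128x64, Cluster 2x1x1, atom divisor 2x1x1.
  Shape3 tile_2sm{256, 128, 64}, cluster_2sm{2, 1, 1};
  Shape3 atom_thr_2sm = shape_div3(cluster_2sm, {2, 1, 1});   // = 1x1x1
  Shape3 out_cta_2sm  = shape_div3(tile_2sm, cluster_2sm);    // 128x128x64 per CTA
  Shape3 mma_tile_2sm = shape_div3(tile_2sm, atom_thr_2sm);   // 256x128x64 for the 2-SM MMA pair
  assert(out_cta_2sm.m == 128 && mma_tile_2sm.m == 256);

  return 0;
}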
+using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16n_f16n_tensor_op_2sm_f16_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = 
cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16n_f16n_tensor_op_2sm_f16_ptr_array, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, 
Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..fb325fac06 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Grouped GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; + +TEST(SM100_Device_Gemm_f16t_f16n_f16n_tensor_op_1sm_f32_group, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using 
KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16n_tensor_op_1sm_f32_group, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = 
typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16n_tensor_op_2sm_f32_group, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, 
ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16n_tensor_op_2sm_f32_group, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, 
OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16n_f16t_f16t_tensor_op_1sm_f32_group, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename 
CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16n_f16n_f16n_tensor_op_1sm_f32_group, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + 
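+  // GemmUniversal above composes three pieces: the grouped problem shape (per-group M/N/K
+  // extents supplied at runtime), the SM100 TMA warp-specialized collective mainloop, and the
+  // ptr-array collective epilogue. GemmUniversalAdapter below wraps the kernel type with the
+  // host-side launch interface, and TestSmall runs it over a small sweep of group/problem
+  // sizes, comparing against the testbed's host reference with the given alpha/beta scalars.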
using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16t_f16t_f16t_tensor_op_1sm_f32_group, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16t_f16t_f16n_tensor_op_2sm_f32_group, 
256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f16t_f16t_f16t_tensor_op_2sm_f32_group, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 
/ cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+// B matrix configuration
+using ElementB = cutlass::half_t; // Element type for B matrix operand
+using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+// C matrix configuration
+using ElementC = cutlass::half_t; // Element type for C matrix operands
+using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+// D matrix configuration
+using ElementD = cutlass::half_t; // Element type for D matrix operands
+using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands
+constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
+// Core kernel configurations
+using ElementAccumulator = float; // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
+using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
+using TileShape_MNK = Shape<_256,_256,_64>;
+using ClusterShape_MNK = Shape<_2,_2,_1>;
+using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{}));
+using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{}));
+using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{}));
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
+using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
+    OutputCtaShape, ClusterShape_MNK,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ElementAccumulator, ElementAccumulator,
+    ElementC, LayoutC *, AlignmentC,
+    ElementD, LayoutD *, AlignmentD,
+    EpilogueSchedule
+  >::CollectiveOp;
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    ArchTag, OperatorClass,
+    ElementA, LayoutA *, AlignmentA,
+    ElementB, LayoutB *, AlignmentB,
+    ElementAccumulator,
+    MmaTileShape, ClusterShape_MNK,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    KernelSchedule
+  >::CollectiveOp;
+using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+    cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
+    CollectiveMainloop,
+    CollectiveEpilogue
+>;
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestSmall<Gemm>(2.0, 2.0);
+  EXPECT_TRUE(result);
+}
+
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
new file mode 100644
index 0000000000..8f1126f0d2
--- /dev/null
+++ b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu
@@ -0,0 +1,665 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; +TEST(SM100_Device_Gemm_f16t_f16n_f16n_tensor_op_1sm_f32_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 
Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} +TEST(SM100_Device_Gemm_f16t_f16n_f16n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // 
Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16n_f16n_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using 
ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16n_f16n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // 
Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16n_f16n_tensor_op_2sm_f32_ptr_array, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = 
decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16n_f16t_f16t_tensor_op_1sm_f32_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using 
EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16n_f16n_f16n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, 
cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16t_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule 
+ >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, 
ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f16t_tensor_op_2sm_f32_ptr_array, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + 
cutlass::gemm::ArrayProblemShape<Shape<int,int,int>>,
+    CollectiveMainloop,
+    CollectiveEpilogue
+>;
+  using namespace test::gemm::device;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  bool result = TestSmall<Gemm>(2.0, 2.0);
+  EXPECT_TRUE(result);
+}
+
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu
new file mode 100644
index 0000000000..7547b757ce
--- /dev/null
+++ b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu
@@ -0,0 +1,250 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Tests for device-wide GEMM interface with stream-K scheduling +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; + +TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_stream_k, 128x256x64_1x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_256,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); + using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Testbed = Testbed3x; + bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_stream_k, 256x128x64_2x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); + using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + 
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Testbed = Testbed3x; + bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_stream_k, 256x256x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_256,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); + using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Testbed = Testbed3x; + bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); + EXPECT_TRUE(result); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_stream_k, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_256,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); + using 
MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape<int,int,int,int>, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; + using Testbed = Testbed3x<Gemm>; + bool result = TestSmall<Gemm>(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu new file mode 100644 index 0000000000..ea7389a747 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#include <iostream> + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_f16t_f16t_f32_void_f16n_tensor_op, 128x256x64_1x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_256,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); + using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape<int,int,int,int>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; + auto pass = test::gemm::device::TestSmall<Gemm>(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32_ptr_array.cu new file mode 100644 index 0000000000..d71c1bbcd3 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32_ptr_array.cu @@ -0,0 +1,664 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; +TEST(SM100_Device_Gemm_f16t_f16n_f32n_tensor_op_1sm_f32_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units 
of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} +TEST(SM100_Device_Gemm_f16t_f16n_f32n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using 
OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16n_f32n_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using 
OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16n_f32n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch 
+using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16n_f32t_tensor_op_2sm_f32_ptr_array, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, 
ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_1sm_f32_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16n_f16t_f32n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + 
static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16n_f16t_f32n_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using 
namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f16n_f16n_f32n_tensor_op_2sm_f32_ptr_array, 256x256x64_2x2x1) { +// A matrix 
configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..80a7c41383 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_group_gemm.cu @@ 
-0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Grouped GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; + +TEST(SM100_Device_Gemm_f32t_f32n_f32n_tensor_op_1sm_f32_group, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, 
LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape<Shape<int,int,int>>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; + bool result = TestSmall<Gemm>(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32n_tensor_op_1sm_f32_group, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + 
>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32n_tensor_op_2sm_f32_group, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool 
result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32n_f32t_tensor_op_2sm_f32_group, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + + +TEST(SM100Only_Device_Gemm_f32t_f32t_f32n_tensor_op_1sm_f32_group, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; 
// Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32n_f32n_f32n_tensor_op_1sm_f32_group, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix 
operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32n_f32t_f32n_tensor_op_1sm_f32_group, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix 
configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32t_f32t_f32n_tensor_op_2sm_f32_group, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C 
matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100Only_Device_Gemm_f32n_f32n_f32n_tensor_op_2sm_f32_group, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / 
cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_ptr_array.cu new file mode 100644 index 0000000000..66038626af --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f32_f32_f32_tensor_op_f32_ptr_array.cu @@ -0,0 +1,667 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +using namespace cute; + +TEST(SM100_Device_Gemm_f32t_f32n_f32n_tensor_op_1sm_f32_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = 
decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 1.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32t_f32n_f32n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32t_f32n_f32n_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + 
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32t_f32n_f32n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + 
>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32t_f32n_f32t_tensor_op_2sm_f32_ptr_array, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + 
cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + + +TEST(SM100_Device_Gemm_f32t_f32t_f32n_tensor_op_1sm_f32_ptr_array, 64x128x64_1x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_64,_128,_64>; +using ClusterShape_MNK = Shape<_1,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + 
CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32n_f32n_f32n_tensor_op_1sm_f32_ptr_array, 128x128x64_1x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_128,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32n_f32t_f32n_tensor_op_1sm_f32_ptr_array, 128x64x64_1x2x1) { 
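+// Note on the configuration below (derivable from the type aliases that follow): the
+// test name encodes the operand layouts (t = RowMajor, n = ColumnMajor for A, B, C/D)
+// and the <tile MxNxK>_<cluster MxNxK> pair. With TileShape 128x64x64 and ClusterShape
+// 1x2x1 on a 1-SM schedule, AtomThrShape equals the cluster shape, so OutputCtaShape
+// and MmaTileShape both resolve to 128x32x64: each CTA in the cluster computes half of
+// the N extent. The 2-SM variants instead divide the cluster shape by 2 along M, so
+// their MMA tile spans a CTA pair.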
+// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_128,_64,_64>; +using ClusterShape_MNK = Shape<_1,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(3.0, 2.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32t_f32t_f32n_tensor_op_2sm_f32_ptr_array, 256x128x64_2x1x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A 
matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_128,_64>; +using ClusterShape_MNK = Shape<_2,_1,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM100_Device_Gemm_f32n_f32n_f32n_tensor_op_2sm_f32_ptr_array, 256x256x64_2x2x1) { +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::ColumnMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / 
cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C matrix configuration +using ElementC = float; // Element type for C matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// D matrix configuration +using ElementD = float; // Element type for D matrix operands +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape_MNK = Shape<_256,_256,_64>; +using ClusterShape_MNK = Shape<_2,_2,_1>; +using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{})); +using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); +using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestSmall(2.0, 2.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..bd15382564 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu @@ -0,0 +1,327 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_group, 512x256x256_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_256,_256>; + using ClusterShape = Shape<_4,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementC), + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_group, 256x384x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_384,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 4, + ElementD, LayoutC *, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_group, 256x512x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_512,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = 
decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 4, + ElementD, LayoutC *, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_2sm_f32_group, 256x256x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_256,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 4, + ElementD, LayoutC *, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_2sm_f32_group, 
512x768x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_768,_256>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 4, + ElementD, LayoutC *, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_ptr_array.cu new file mode 100644 index 0000000000..7c33f2daf3 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_ptr_array.cu @@ -0,0 +1,327 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_ptr_array, 512x256x256_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_256,_256>; + using ClusterShape = Shape<_4,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 32, + MmaTypePairB, LayoutB, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + 
CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_ptr_array, 256x384x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_384,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 32, + MmaTypePairB, LayoutB, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_ptr_array, 256x512x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_512,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, 
+ ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 32, + MmaTypePairB, LayoutB, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_2sm_f32_ptr_array, 256x256x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_256,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 32, + MmaTypePairB, LayoutB, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_2sm_f32_ptr_array, 512x768x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_768,_256>; + using ClusterShape 
= Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 32, + MmaTypePairB, LayoutB, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu new file mode 100644 index 0000000000..aaf6d62233 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_runtime_datatype, 512x512x128_4x4x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + float, + float, + float, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float4_t, cutlass::layout::RowMajor, 128, + cutlass::type_erased_dynamic_float4_t, cutlass::layout::ColumnMajor, 128, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E2M1, cute::UMMA::MXF8F6F4Format::E2M1); + EXPECT_TRUE(pass); + +} + + +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + float, + float, + float, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float4_t, cutlass::layout::RowMajor, 128, + 
cutlass::type_erased_dynamic_float4_t, cutlass::layout::ColumnMajor, 128, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E2M1, cute::UMMA::MXF8F6F4Format::E2M1); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_ptr_array.cu new file mode 100644 index 0000000000..35bf67c3df --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_ptr_array.cu @@ -0,0 +1,486 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32t_tensorop_1sm_f32_ptr_array, 128x128x256_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_128,_128,_256>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32n_tensorop_1sm_f32_ptr_array, 256x512x256_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_512,_256>; + using ClusterShape = Shape<_2,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using 
OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32t_tensorop_1sm_f32_ptr_array, 512x768x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_768,_256>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32n_tensorop_1sm_f32_ptr_array, 
512x1024x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_1024,_256>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32n_tensorop_2sm_f32_ptr_array, 256x256x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_256,_256,_256>; + using ClusterShape = Shape<_2,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + 
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32t_tensorop_2sm_f32_ptr_array, 512x512x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_512,_256>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32n_tensorop_2sm_f32_ptr_array, 512x768x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_768,_256>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = 
decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e2m3t_e2m3n_f32t_tensorop_2sm_f32_ptr_array, 512x1024x256_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e2m3_t; + using ElementB = cutlass::float_e2m3_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using ClusterTileShape = cute::Shape<_512,_1024,_256>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 4, + ElementD, LayoutC, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA, 128, + MmaTypePairB, LayoutB, 128, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu 
b/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu new file mode 100644 index 0000000000..a2f0971fa7 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e3m2t_e2m3n_f32t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + float, + float, + float, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float6_t, cutlass::layout::RowMajor, 128, + cutlass::type_erased_dynamic_float6_t, cutlass::layout::ColumnMajor, 128, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E3M2, cute::UMMA::MXF8F6F4Format::E2M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e3m2t_e2m3n_f32t_tensorop_1sm_f32_runtime_datatype, 512x512x128_4x4x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + float, + float, + float, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float6_t, cutlass::layout::RowMajor, 128, + cutlass::type_erased_dynamic_float6_t, cutlass::layout::ColumnMajor, 128, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E3M2, cute::UMMA::MXF8F6F4Format::E2M3); + EXPECT_TRUE(pass); + +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git 
a/test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu new file mode 100644 index 0000000000..bdd342fe60 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e4m3t_e2m1n_f32t_tensorop_2sm_f32_runtime_datatype, 256x128x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized2Sm, + + cutlass::epilogue::fusion::LinearCombination< + float, + float, + float, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, + cutlass::type_erased_dynamic_float4_t, cutlass::layout::ColumnMajor, 128, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E4M3, cute::UMMA::MXF8F6F4Format::E2M1); + EXPECT_TRUE(pass); + +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..5183c937c9 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_group_gemm.cu @@ -0,0 +1,504 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide Grouped GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_group, 64x128x128_1x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_64,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, 
+ CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_2sm_f32_group, 256x128x128_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_256,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_512,_512,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + 
EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_e4m3n_tensorop_1sm_f32_group, 128x128x128_1x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_e4m3n_tensorop_1sm_f32_group, 64x128x128_1x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_64,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_2,_1>; + using AtomThrShape = 
decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3t_e4m3n_tensorop_2sm_f32_group, 256x128x128_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_256,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = 
test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3t_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_512,_512,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3t_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1_silu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_512,_512,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementD), + EpilogueSchedule, + cutlass::epilogue::fusion::LinCombEltAct + >::CollectiveOp; + + using MainloopSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(2.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3t_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1_voidC_silu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ClusterTileShape = cute::Shape<_512,_512,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + void, LayoutD *, 16 / sizeof(ElementD), + ElementD, LayoutD *, 16 / sizeof(ElementD), + EpilogueSchedule, + cutlass::epilogue::fusion::LinCombEltAct + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA *, 16 / sizeof(ElementA), + ElementB, LayoutB *, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(2.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_ptr_array.cu new file mode 100644 index 0000000000..6969d00162 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_ptr_array.cu @@ -0,0 +1,465 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////// 128x128x128 ////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_ptr_array, 128x128x128_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using 
MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_ptr_array, 64x128x128_1x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_64,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_2sm_f32_ptr_array, 
256x128x128_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_256,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_2sm_f32_ptr_array, 512x512x128_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_512,_512,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3n_e4m3t_e4m3n_tensorop_1sm_f32_ptr_array, 128x128x128_1x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3n_e4m3n_e4m3n_tensorop_1sm_f32_ptr_array, 64x128x128_1x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_64,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_2,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = 
decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3t_e4m3n_tensorop_2sm_f32_ptr_array, 256x128x128_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_256,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + 
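+// Tile-shape bookkeeping shared by the tests in this file: ClusterTileShape
+// is the tile computed by the whole CTA cluster; AtomThrShape (ClusterShape
+// divided by Shape<_2,_1,_1> for the 2-SM schedules, by Shape<_1,_1,_1> for
+// the 1-SM ones) describes how the MMA atoms tile that cluster; MmaTileShape =
+// ClusterTileShape / AtomThrShape is the tile one MMA atom computes; and
+// OutputCtaShape = ClusterTileShape / ClusterShape is the per-CTA tile handed
+// to the epilogue builder. For the 512x512x128_4x4x1 test below this works
+// out to AtomThrShape = (2,4,1), MmaTileShape = (256,128,128), and
+// OutputCtaShape = (128,128,128).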
+TEST(SM100_Device_Gemm_e4m3t_e4m3t_e4m3t_tensorop_2sm_f32_ptr_array, 512x512x128_4x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_512,_512,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu new file mode 100644 index 0000000000..74791e83e8 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu @@ -0,0 +1,297 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_2sm_f32_runtime_datatype, 256x128x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::epilogue::TmaWarpSpecialized2Sm, + + cutlass::epilogue::fusion::LinearCombination< + cutlass::float_e4m3_t, + float, + cutlass::float_e4m3_t, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + 
cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + cutlass::float_e4m3_t, + float, + cutlass::float_e4m3_t, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e4m3t_e5m2n_e4m3t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + cutlass::float_e4m3_t, + float, + cutlass::float_e4m3_t, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E4M3, cute::UMMA::MXF8F6F4Format::E5M2); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + cutlass::float_e4m3_t, + float, + cutlass::float_e4m3_t, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 + >::CollectiveOp; + 
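+  // The mainloop operands use the type-erased carrier
+  // cutlass::type_erased_dynamic_float8_t, so the kernel is instantiated once
+  // and the concrete FP8 encodings of A and B are selected at runtime: this
+  // test passes cute::UMMA::MXF8F6F4Format::E4M3 for A and E5M2 for B to the
+  // runtime-datatype test harness below.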
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E4M3, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e5m2t_e5m2n_e5m2t_tensorop_2sm_f32_runtime_datatype, 256x256x128_2x2x1) { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, cutlass::layout::RowMajor, 16, + cutlass::float_e5m2_t, cutlass::layout::RowMajor, 16, + cutlass::epilogue::TmaWarpSpecialized1Sm, + + cutlass::epilogue::fusion::LinearCombination< + cutlass::float_e5m2_t, + float, + cutlass::float_e5m2_t, + float + > + + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E5M2); + EXPECT_TRUE(pass); + +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu new file mode 100644 index 0000000000..187d820cd4 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////// Test Batch alpha and beta ////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1cta_s32_batch_alpha_beta, 128x64x128_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + + using FusionOperation = cutlass::epilogue::fusion::LinearCombination< + ElementD, + ElementCompute, + ElementC, + ElementBias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, 
ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 1.0); // beta is [1.0, 2.0] + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_bias_relu_batch_alpha_beta, 128x128x128_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0.5); // beta is [0.5, 1.5] + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_bias_relu__batch_alpha_beta0, 128x128x128_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, 
ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, -1.0); // beta is [-1.0, 0.0] + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu new file mode 100644 index 0000000000..0b18aea2d3 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////// 128x64x128 Cluster1x1x1 TMEM 4x1 //////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_s8t_s8n_s8n_tensorop_1cta_s32_ptr_array, 128x64x128_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = int8_t; + using ElementB = int8_t; + using ElementC = int8_t; + using ElementD = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementBias = int8_t; + using ClusterTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + 
CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = TestSmall(2, 0.5, CheckEquality::EXACT); + EXPECT_TRUE(pass); +} +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////// 128x64x128 Cluster4x2x1 TMEM 4x1 //////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_s8t_s8n_s8n_tensorop_1cta_s32_ptr_array, 512x128x128_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = int8_t; + using ElementB = int8_t; + using ElementC = int8_t; + using ElementD = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementBias = int8_t; + using ClusterTileShape = Shape<_512,_128,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_4,_2,_1>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = TestSmall(2, 0.5, CheckEquality::EXACT); + EXPECT_TRUE(pass); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////// 64x256x128 Cluster1x1x1 TMEM 4x1 //////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_s8t_s8n_s32n_tensorop_1cta_s32_ptr_array, 64x256x128_1x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = int8_t; + using ElementB = int8_t; + using ElementC = int32_t; + using ElementD = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + 
using ElementBias = int32_t; + using ClusterTileShape = cute::Shape<_64,_256,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_1,_1,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); + + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = TestSmall(2, 0.5, CheckEquality::EXACT); + EXPECT_TRUE(pass); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////// 64x256x128 Cluster2x4x1 TMEM 2x2 //////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100_Device_Gemm_s8t_s8n_s8n_tensorop_2cta_s32_ptr_array, 128x1024x128_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = int8_t; + using ElementB = int8_t; + using ElementC = int8_t; + using ElementD = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementBias = int8_t; + using ClusterTileShape = Shape<_128,_1024,Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2,_4,_1>; + using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); + + using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementD, LayoutC, 16 / sizeof(ElementD), + EpilogueSchedule + >::CollectiveOp; + + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, 
+ ElementA, LayoutA, 16 / sizeof(ElementA), + ElementB, LayoutB, 16 / sizeof(ElementB), + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = TestSmall(2, 0.5, CheckEquality::EXACT); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + diff --git a/test/unit/gemm/device/sm100_gemm_mxf4_mxf8_mxf8_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_mxf4_mxf8_mxf8_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..a6bbcfce55 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_mxf4_mxf8_mxf8_tensor_op_f32_group_gemm.cu @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/// A Row B Col +TEST(SM100Only_Device_Gemm_e2m1t_e4m3n_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::float_e4m3_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 128, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestAll(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Row +TEST(SM100Only_Device_Gemm_e2m1n_e4m3t_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::float_e4m3_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; 
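+  // A is the block-scaled 4-bit e2m1 operand (paired with ue8m0 scale factors via MmaTypePairA); the + // CollectiveBuilder below is instantiated with 128-element alignment for this narrow operand versus 16 for the + // e4m3 B operand.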
+ using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 128, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestAll(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Row B Row +TEST(SM100Only_Device_Gemm_e2m1t_e4m3t_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::float_e4m3_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 128, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = 
cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestAll(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Col +TEST(SM100Only_Device_Gemm_e2m1n_e4m3n_e4m3t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementD = cutlass::float_e4m3_t; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 128, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestAll(1.0, 2.0); + EXPECT_TRUE(pass); +} +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu b/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu new file mode 100644 index 0000000000..0ee4c2bba0 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu @@ -0,0 +1,281 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/// A Row B Col +TEST(SM100_Device_Gemm_e4m3t_e4m3n_f32t_tensorop_2sm_f32_auto, 512x512x128_4x4x1) { + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 4, + ElementD, GmemLayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA, 16, + ElementB, GmemLayoutB, 16, + ElementAccumulator, + MmaTileShape_MNK, 
ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Row +TEST(SM100_Device_Gemm_e4m3n_e4m3t_f32t_tensorop_2sm_f32_auto, 512x512x128_4x4x1) { + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 4, + ElementD, GmemLayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA, 16, + ElementB, GmemLayoutB, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Row B Row +TEST(SM100_Device_Gemm_e4m3t_e4m3t_f32t_tensorop_2sm_f32_auto, 512x512x128_4x4x1) { + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 4, + ElementD, GmemLayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + 
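+  // EpilogueScheduleAuto here (and KernelScheduleAuto in the mainloop constructed below) delegate schedule + // selection to the collective builders rather than pinning a specific TMA warp-specialized schedule, which is + // what the _auto test variants in this file exercise.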
+ // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA, 16, + ElementB, GmemLayoutB, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Col +TEST(SM100_Device_Gemm_e4m3n_e4m3n_f32t_tensorop_2sm_f32_auto, 512x512x128_4x4x1) { + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC, 4, + ElementD, GmemLayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + ElementA, GmemLayoutA, 16, + ElementB, GmemLayoutB, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmallFusion(1.0, 0); + EXPECT_TRUE(pass); +} +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..c671995437 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_group_gemm.cu @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/// A Row B Col +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_f32t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + 
cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 16, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Row +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_f32t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 16, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Row B Row +TEST(SM100Only_Device_Gemm_e4m3t_e4m3t_f32t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, 
ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::RowMajor; + using GmemLayoutB = cutlass::layout::RowMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 16, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0); + EXPECT_TRUE(pass); +} + +/// A Col B Col +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_f32t_tensorop_2sm_f32_group, 512x512x128_4x4x1) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = float; + using ElementCompute = float; + using ElementAccumulator = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = decltype(cute::make_tuple(ElementA{}, ElementSF{})); + using MmaTypePairB = decltype(cute::make_tuple(ElementB{}, ElementSF{})); + using ElementAccumulator = float; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using GmemLayoutB = cutlass::layout::ColumnMajor; + using GmemLayoutC = cutlass::layout::RowMajor; + using ClusterTileShape_MNK = Shape<_512,_512,_128>; + using ClusterShape_MNK = Shape<_4,_4,_1>; + using MmaTileShape_MNK = Shape<_256,_128,_128>; + using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + OutputCtaShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, GmemLayoutC *, 16, + ElementD, GmemLayoutC *, 16, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, GmemLayoutA *, 16, + MmaTypePairB, GmemLayoutB *, 16, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + 
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0); + EXPECT_TRUE(pass); +} +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu b/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu index 58a5f86334..2bcf83cd2a 100644 --- a/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu +++ b/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu @@ -87,6 +87,7 @@ TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align1, 6 ///////////////////////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { using ElementOutput = float; @@ -124,6 +125,8 @@ TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 1 EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); } +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/threadblock/mma_multistage.cu b/test/unit/gemm/threadblock/mma_multistage.cu index cee23cb267..21ecacdc37 100644 --- a/test/unit/gemm/threadblock/mma_multistage.cu +++ b/test/unit/gemm/threadblock/mma_multistage.cu @@ -2974,6 +2974,7 @@ TEST(SM80_gemm_threadblock_crosswise, } //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x1024_64x64x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3006,8 +3007,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x1024_32x32x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3040,8 +3044,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x64x1024_64x32x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3074,8 +3081,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x1024x1024_32x64x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3108,8 +3118,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x1024x1024_64x64x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3142,8 +3155,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 
0 TEST(SM80_gemm_threadblock_crosswise, multicta_256x256x6144_128x1024x1024_64x64x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3176,8 +3192,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, multicta_512x256x6144_256x1024x1024_64x64x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3210,8 +3229,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x512_64x64x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3244,8 +3266,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x512_32x32x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3278,8 +3303,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x64x512_64x32x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3312,8 +3340,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x128x512_32x64x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3346,8 +3377,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x128x512_64x64x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3380,8 +3414,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, multicta_256x256x6144_128x128x512_64x64x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3414,8 +3451,11 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// +#if 0 TEST(SM80_gemm_threadblock_crosswise, multicta_512x256x6144_256x128x512_64x64x512_16x8x256_4stage) { using ElementA = cutlass::uint1b_t; @@ -3448,6 +3488,8 @@ TEST(SM80_gemm_threadblock_crosswise, problem_size.k(), alpha, beta) .run(grid, block); } +#endif + //////////////////////////////////////////////////////////////////////////////// TEST(SM80_gemm_threadblock_congruous, tensor_op_64x64x16_32x64x16_8x8x4_3stage) { diff --git a/test/unit/pipeline/CMakeLists.txt b/test/unit/pipeline/CMakeLists.txt index 81bdbf3214..230734a1e0 100644 --- a/test/unit/pipeline/CMakeLists.txt +++ b/test/unit/pipeline/CMakeLists.txt @@ -31,6 +31,7 @@ cutlass_test_unit_add_executable( pipeline_tma_async.cu pipeline_tma_async_warp_specialized.cu pipeline_tma_async_warp_specialized_persistent.cu + 
pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu pipeline_async.cu sequence_barrier.cu ) diff --git a/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu b/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu new file mode 100644 index 0000000000..ae2fa4e29e --- /dev/null +++ b/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu @@ -0,0 +1,381 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Unit test for the PipelineCLCFetchAsync class +*/ + +// + +// + +#define KERNEL_DBG_TRACE false + +#include +#include "../common/cutlass_unit_test.h" +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/core_io.h" +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include "testbed_cluster_launch_control.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" + + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm::kernel::detail; + +//////////////////// Shared Memory ///////////////////////// + +template +struct SharedStorage +{ + alignas(16) typename PersistentTileSchedulerSm100::CLCResponse clc_response[Stages]; + alignas(8) typename PersistentTileSchedulerSm100::PipelineStorage storage ; +}; + +//////////////////// Kernel ///////////////////////// +template +__launch_bounds__(256, 1) +__global__ static +void pipeline_device(int *d_workerCount) +{ + extern __shared__ char shared_memory[]; + + // single producer, multiple consumers + // producer: WG0 + // consumer: WG1 + + using SharedStorage = SharedStorage; + using Scheduler = PersistentTileSchedulerSm100; + using TileSchedulingPipeline = typename Scheduler::Pipeline; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Logistics + int warp_idx = canonical_warp_idx(); + auto cluster_shape = ClusterShape{}; + + typename TileSchedulingPipeline::Params params; + params.transaction_bytes = 16; + + constexpr int NUM_PRODUCER = 32; + constexpr int NUM_CONSUMERS_PER_CTA = 32; + params.consumer_arv_count = NUM_PRODUCER + NUM_CONSUMERS_PER_CTA * cute::size<0>(cluster_shape) * cute::size<1>(cluster_shape); + params.producer_arv_count = 1; + // Only the first CTA in the Cluster is producing. + params.producer_blockid = 0; + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + // mbarrier.init + TileSchedulingPipeline scheduler_pipeline(shared_storage.storage, params ); + Scheduler scheduler(&shared_storage.clc_response[0], typename Scheduler::Params{}, block_id_in_cluster); + + // Ensure All CTAs in Cluster have completed init before issuing commits + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + + uint32_t is_first_block_in_cluster = block_id_in_cluster.x == 0 && block_id_in_cluster.y == 0; + int lane_predicate = cute::elect_one_sync(); + + uint32_t is_producer = (is_first_block_in_cluster && warp_idx == 0); + uint32_t is_consumer = (warp_idx == 4); + + PipelineState scheduler_pipe_state; + PipelineState scheduler_pipe_state_write = cutlass::make_producer_start_state(); + typename Scheduler::WorkTileInfo work_tile_info = { + static_cast(blockIdx.x), + static_cast(blockIdx.y), + static_cast(blockIdx.z), + false + }; + + // Persistent loop + do { + // Producer + if (is_producer) { + // Only 1 thread of the entire cluster issues the query. + scheduler_pipe_state_write = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_state_write); + } + + // Consumers + if (is_consumer) { + int linearCLC = work_tile_info.N_idx * gridDim.x + work_tile_info.M_idx; + // Atomically increment the worker count for the linearCLC by 1. + if (lane_predicate) { + atomicAdd(&d_workerCount[linearCLC], 1); + } + } + + // Union of all consumers. Note that the producer here is its own consumer. 
+ if (is_producer || is_consumer) { + scheduler_pipeline.consumer_wait(scheduler_pipe_state); + work_tile_info = scheduler.get_current_work(scheduler_pipe_state); + scheduler_pipeline.consumer_release(scheduler_pipe_state); + ++scheduler_pipe_state; + + // Add block offset since the scheduler works at cluster level. + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + work_tile_info.M_idx += block_id_in_cluster.x; + work_tile_info.N_idx += block_id_in_cluster.y; + work_tile_info.L_idx += block_id_in_cluster.z; + + } + } while (work_tile_info.is_valid_tile); + + // End of kernel + cute::cluster_sync(); +} +///////////////////////////////////////////////////// + +template +struct PipelineTest { + + // + // Data members + // + static constexpr uint32_t Stages = Stages_; + static constexpr uint32_t BlockSize = 128 * 2; + using ClusterShape = ClusterShape_; + + // + // Methods + // + + bool check_results(int *h_workerCount, int size ) { + for (int i = 0 ; i< size; i++ ){ + if ( h_workerCount[i] != 1 ) + { + std::cout << "linearCLC " << i << " has worker count " << h_workerCount[i] << "\n"; + return false; + } + } + return true; + } + + // Run CuTe GEMM kernel + cudaError_t run(bool &success, dim3 grid_dim, + cudaStream_t stream = 0 ) { + + // + // Configure and launch + // + cudaError_t result; + + int smem_size = 192 * 1024; // 192kB to force 1CTA/SM + auto cluster_shape = Shape, Int, _1>{}; + // Launch a single Cluster, with BlockSize threads per CTA + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimGrid = grid_dim; + dim3 dimBlock(BlockSize,1,1); + + result = cudaFuncSetAttribute( + pipeline_device< + decltype(cluster_shape), + Stages>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + ); + + if (result != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." << std::endl; + return result; + } + + int array_size = dimGrid.x * dimGrid.y; + int *d_workerCount, *h_workerCount; + + /* Allocate memory. workerCount[i] counts the number of worker(s) which work + on linear t i. The expectation is that workerCount[i] == 1 for all i. + */ + h_workerCount = (int*)malloc(array_size * sizeof(int)); + + result = cudaMalloc(&d_workerCount, array_size * sizeof(int)); + if (result != cudaSuccess) { + std::cerr << "Failed to do cudaMalloc." << result << "\n"; + return result; + } + + for(int i = 0 ; i < array_size; i++) + { + h_workerCount[i] = 0; // Initialize workerCount[i] to 0 for all i. + } + + result = cudaMemcpy(d_workerCount, h_workerCount, array_size * sizeof(int), cudaMemcpyHostToDevice); + if (result != cudaSuccess) { + std::cerr << "Failed to do cudaMemcpy." << result << "\n"; + return result; + } + + // Extended launch API + const void* kernel = (const void*)pipeline_device; + void* kernel_params[] = {&d_workerCount}; + cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; + return result; + } + + result = cudaMemcpy(h_workerCount, d_workerCount, array_size * sizeof(int), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + std::cerr << "Failed to do cudaMemcpy." << result << "\n"; + return result; + } + + success = check_results(h_workerCount, array_size); + + free(h_workerCount); + + result = cudaFree(d_workerCount); + if (result != cudaSuccess) { + std::cerr << "Failed to do cudaFree." 
<< result << "\n"; + return result; + } + + return cudaSuccess; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +//Cluster1x2 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster1x2_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster2x1 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster2x1_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster2x2 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster2x2_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster1x1 Stage3 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster1x1_Stage3) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr uint32_t Stages = 3; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster1x4 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster1x4_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster4x1 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster4x1_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster2x4 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster2x4_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster4x2 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster4x2_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} + +//Cluster4x4 Stage4 +TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster4x4_Stage4) { + Options options; + options.grid_dim = {32,32,1}; + using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; + static constexpr uint32_t Stages = 4; + using Test = PipelineTest; + Testbed testbed(options); + EXPECT_TRUE(testbed.verification()); +} +#endif diff --git a/test/unit/pipeline/testbed_cluster_launch_control.h b/test/unit/pipeline/testbed_cluster_launch_control.h new file mode 100644 index 0000000000..4ac892de36 --- /dev/null +++ b/test/unit/pipeline/testbed_cluster_launch_control.h @@ -0,0 +1,154 @@ 
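Aside (not part of the patch): for readers tracing the PipelineTest harness above, its host side reduces to the extended cluster launch pattern sketched below. This is a minimal sketch only; `toy_kernel` and its single pointer argument are placeholders, while `cudaFuncSetAttribute` and `cutlass::ClusterLauncher::launch` are used exactly as in `PipelineTest::run`.

// --- Illustrative sketch: minimal host-side cluster launch, mirroring PipelineTest::run ---
// toy_kernel and its argument are placeholders; the smem request and launch call follow the test.
#include <cuda_runtime.h>
#include "cutlass/cluster_launch.hpp"

__global__ void toy_kernel(int *out) {
  // Cluster-aware device code would go here.
  if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
    out[0] = 1;
  }
}

cudaError_t launch_toy_kernel(int *d_out, cudaStream_t stream = nullptr) {
  int smem_size = 192 * 1024;   // same request as the unit test (assumes the GPU carveout allows it)
  dim3 grid(32, 32, 1);         // one CTA per (x, y) work coordinate
  dim3 cluster(2, 1, 1);        // CTAs per cluster
  dim3 block(256, 1, 1);        // two warpgroups per CTA, as in the test kernel

  cudaError_t result = cudaFuncSetAttribute(
      toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  if (result != cudaSuccess) {
    return result;
  }

  void const *kernel = reinterpret_cast<void const *>(&toy_kernel);
  void *kernel_params[] = {&d_out};
  cutlass::ClusterLauncher::launch(
      grid, cluster, block, smem_size, stream, kernel, kernel_params);

  return cudaDeviceSynchronize();
}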
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed file used by cluster launch control pipeline unit test +*/ + +// + +// + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +// Command line test options +struct Options { + // + // Data Members + // + bool help = false; + bool verification_enabled = true; + int SM_count = 116; + int clock_MHz = 1477; + dim3 grid_dim = {0,0,0}; + + // + // Methods + // + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, verification_enabled); + cmd.get_cmd_line_argument("sm-count", SM_count, SM_count); + cmd.get_cmd_line_argument("clock", clock_MHz, clock_MHz); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +class Testbed { +private: + // Commandline options + Options options; + + bool run_test() { + + // Run CuTe Gemm + Pipeline pipeline; + + bool success = false; + cudaError_t result = pipeline.run(success, this->options.grid_dim); + + CUTE_CHECK_LAST(); + return success; + } + + +public: + Testbed(Options const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + +#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + printf( + "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be set, but it is not. \n" + "This test is waived.\n" + ); + return true; +#endif + +#if 1 + bool is_success = false; + for (int i = 0; i< 10; i++){ + printf("iteration = %d\n", i); + is_success = run_test(); + if ( not is_success ) + return is_success; + } + return is_success; +#else + // Run the test with single launch + return run_test(); +#endif + } +}; diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 2052dd2c08..bfe1b5f48c 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -221,6 +221,19 @@ cutlass_add_cutlass_library( # files split for parallel compilation src/reference/gemm_int4.cu + + src/reference/block_scaled_gemm_fp4a_vs16.cu + src/reference/block_scaled_gemm_fp4a_vs32.cu + src/reference/block_scaled_gemm_mixed8bitsa.cu + src/reference/gemm_f4_f4_f32.cu + src/reference/gemm_f4_f6_f32.cu + src/reference/gemm_f4_f8_f32.cu + src/reference/gemm_f6_f4_f32.cu + src/reference/gemm_f6_f6_f32.cu + src/reference/gemm_f6_f8_f32.cu + src/reference/gemm_f8_f4_f32.cu + src/reference/gemm_f8_f6_f32.cu + src/reference/gemm_s8_s8_s32.cu src/reference/gemm_u8_u8_s32.cu src/reference/gemm_int8_interleaved_32.cu diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h index eee0c78608..29f70d3a65 100644 --- a/tools/library/include/cutlass/library/arch_mappings.h +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -119,6 +119,18 @@ template <> struct ArchMap { static int const kMax = 90; }; + +template struct ArchMap { + static int const kMin = 100; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 100; + static int const kMax = 100; +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/include/cutlass/library/descriptions.h b/tools/library/include/cutlass/library/descriptions.h index ae96395f9a..fb56b027b9 100644 --- a/tools/library/include/cutlass/library/descriptions.h +++ b/tools/library/include/cutlass/library/descriptions.h @@ -300,6 +300,101 @@ struct GemmDescription : public OperationDescription { transform_B(transform_B) {} }; + +/// Description of all GEMM computations +struct BlockScaledGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + 
GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + // + // Methods + // + + BlockScaledGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockScaledGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Description for structured sparse GEMMs. 
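Aside (not part of the patch): the struct above is what the profiler keys on for block-scaled kernels. As a minimal, purely illustrative sketch, a hypothetical nvfp4 operation might describe itself as follows; the concrete element types and vector sizes are placeholders and vary per kernel.

// Illustrative only: fill the new BlockScaledGemmDescription for a hypothetical nvfp4 x nvfp4 kernel.
// Field names come from the struct above; the values are placeholders.
#include "cutlass/library/library.h"

cutlass::library::BlockScaledGemmDescription make_example_description() {
  using namespace cutlass::library;

  BlockScaledGemmDescription desc;
  desc.kind              = OperationKind::kBlockScaledGemm;
  desc.gemm_kind         = GemmKind::kUniversal;
  desc.A.element         = NumericTypeID::kFE2M1;    // 4-bit operand data
  desc.SFA.element       = NumericTypeID::kFUE4M3;   // per-block scale factors for A
  desc.B.element         = NumericTypeID::kFE2M1;
  desc.SFB.element       = NumericTypeID::kFUE4M3;
  desc.element_epilogue  = NumericTypeID::kF32;
  desc.SFVecSize         = 16;                       // input scale-factor vector size
  desc.EpilogueSFVecSize = 16;                       // output (SFD) scale-factor vector size
  return desc;
}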
diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h index bb37b1bc74..027944eb6a 100644 --- a/tools/library/include/cutlass/library/handle.h +++ b/tools/library/include/cutlass/library/handle.h @@ -178,6 +178,15 @@ class Handle { int M, /// GEMM M dimension int N, /// GEMM N dimension int K, /// GEMM K dimension + + int cluster_m, /// cluster shape M dimension + int cluster_n, /// cluster shape N dimension + int cluster_k, /// cluster shape K dimension + int cluster_m_fallback, /// Fallback cluster shape M dimension + int cluster_n_fallback, /// Fallback cluster shape N dimension + int cluster_k_fallback, /// Fallback cluster shape K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation NumericTypeID element_scalar, /// Data type of alpha/beta scalars diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index a4c6572e5f..0309ec3110 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -103,6 +103,7 @@ class Operation { void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; + // Originally designed for metadata, but should be useful for FP8/6/4 too. virtual Status initialize_with_profiler_workspace( void const *configuration, void *host_workspace, @@ -269,6 +270,8 @@ struct GemmUniversalConfiguration { GemmUniversalMode mode{GemmUniversalMode::kGemm}; gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; int batch_count{1}; int64_t lda{0}; @@ -282,6 +285,8 @@ struct GemmUniversalConfiguration { struct GemmUniversalArguments { // NOTE: these are replicated for 3.0 interfaces gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; int batch_count{1}; void const *A{nullptr}; @@ -307,13 +312,68 @@ struct GemmUniversalArguments { // Needed for some 3.x kernels int sm_count{0}; library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{}; + library::RuntimeDatatype runtime_input_datatype_b{}; int swizzle_size{1}; + int split_k_slices{1}; int device_index{0}; bool use_pdl{false}; }; + +/// Block Scaled GEMM +// +// OperationKind: kBlockScaledGemm +// GemmKind: Universal + +struct BlockScaledGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + void *SFD{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for ScaleFactor Generation + void const *norm_constant{nullptr}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + + 
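Aside (not part of the patch): BlockScaledGemmArguments is consumed through the generic Operation interface declared in this header. The following is a minimal host-side sketch of that flow, assuming `op` points at a block-scaled GEMM operation already retrieved from the manifest or operation table; the problem size, leading dimensions, and device pointers are placeholders.

// Sketch only: drive a block-scaled GEMM through the generic Operation interface.
// `op`, the device pointers, the problem size, and the leading dimensions are placeholders.
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>
#include "cutlass/library/library.h"

cutlass::Status run_block_scaled_gemm(
    cutlass::library::Operation const *op,
    void const *A, void const *SFA,
    void const *B, void const *SFB,
    void const *C, void *D,
    cudaStream_t stream) {

  using namespace cutlass::library;

  float alpha = 1.0f, beta = 0.0f;

  GemmUniversalConfiguration config;
  config.mode = GemmUniversalMode::kGemm;
  config.problem_size = {1024, 1024, 1024};
  config.lda = 1024;  config.ldb = 1024;  config.ldc = 1024;  config.ldd = 1024;

  BlockScaledGemmArguments args;
  args.problem_size           = config.problem_size;
  args.cluster_shape          = {2, 1, 1};   // preferred cluster shape
  args.cluster_shape_fallback = {1, 1, 1};   // fallback cluster shape
  args.A = A;    args.SFA = SFA;
  args.B = B;    args.SFB = SFB;
  args.C = C;    args.D = D;
  args.alpha = &alpha;
  args.beta  = &beta;
  args.pointer_mode = ScalarPointerMode::kHost;
  args.lda = config.lda;  args.ldb = config.ldb;
  args.ldc = config.ldc;  args.ldd = config.ldd;

  if (op->can_implement(&config, &args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }

  std::vector<uint8_t> host_workspace(op->get_host_workspace_size(&config));
  // A real caller would also allocate op->get_device_workspace_size(&config, &args)
  // bytes of device memory; this sketch assumes no device workspace is needed.
  cutlass::Status status = op->initialize(&config, host_workspace.data(), nullptr, stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return op->run(&args, host_workspace.data(), nullptr, stream);
}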
///////////////////////////////////////////////////////////////////////////////////////////////// /// Complex valued GEMM in which real and imaginary parts are separated by a stride diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h index 05b84b1e3d..6a8655ceaf 100644 --- a/tools/library/include/cutlass/library/operation_table.h +++ b/tools/library/include/cutlass/library/operation_table.h @@ -243,6 +243,191 @@ using GemmOperationFunctionalMap = std::unordered_map< GemmFunctionalKeyHasher >; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for BlockScaled Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockScaledGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + NumericTypeID element_SFD; + LayoutTypeID layout_SFD; + int SFVecSize; + int EpilogueSFVecSize; + // + // Methods + // + + inline + BlockScaledGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockScaledGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFD = NumericTypeID::kF16, + LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor, + int sf_vec_size = 32 + , int epilogue_sf_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + element_SFD(element_SFD), + layout_SFD(layout_SFD), + SFVecSize(sf_vec_size) + , EpilogueSFVecSize(epilogue_sf_vec_size) + { } + + inline + bool operator==(BlockScaledGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == 
rhs.element_D) && + (layout_D == rhs.layout_D) && + (element_SFD == rhs.element_SFD) && + (layout_SFD == rhs.layout_SFD) && + (SFVecSize == rhs.SFVecSize) + && (EpilogueSFVecSize == rhs.EpilogueSFVecSize) + ; + } + + inline + bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " element_SFD: " << to_string(k.element_SFD) << "\n" + << " layout_SFD: " << to_string(k.layout_SFD) << "\n" + << " SFVecSize: " << k.SFVecSize << "\n" + << "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockScaledGemmFunctionalKeyHasher +struct BlockScaledGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockScaledGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.element_SFD)), 16) ^ + rotl(hash(int(key.layout_SFD)), 17) ^ + rotl(hash(int(key.SFVecSize)), 18) ^ + rotl(hash(int(key.EpilogueSFVecSize)), 19) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockScaledGemmOperationFunctionalMap = std::unordered_map< + BlockScaledGemmFunctionalKey, + GemmOperationVectorMap, + BlockScaledGemmFunctionalKeyHasher +>; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Data Structures for Conv Functional Maps ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -509,6 +694,9 @@ class OperationTable { // provider (kCUTLASS) GemmOperationFunctionalMap 
gemm_operations; + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations; + /// Map of all operations of type kConv2d // provider (kCUTLASS, kReferenceHost, kReferenceDevice) ConvOperationFunctionalMap conv2d_operations; diff --git a/tools/library/include/cutlass/library/types.h b/tools/library/include/cutlass/library/types.h index 5685386347..4b0e36fe4e 100644 --- a/tools/library/include/cutlass/library/types.h +++ b/tools/library/include/cutlass/library/types.h @@ -43,6 +43,7 @@ enum class LayoutTypeID { kUnknown, kColumnMajor, kRowMajor, + kBlockScalingTensor, kColumnMajorInterleavedK2, kRowMajorInterleavedK2, kColumnMajorInterleavedK4, @@ -83,6 +84,16 @@ enum class NumericTypeID { kS64, kFE4M3, kFE5M2, + + kFE2M3, + kFE3M2, + kFE2M1, + kFUE8M0, + kFUE4M3, + kF8, + kF6, + kF4, + kF16, kBF16, kTF32, @@ -131,6 +142,7 @@ enum class Provider { /// Enumeration indicating the kind of operation enum class OperationKind { kGemm, + kBlockScaledGemm, kRankK, kRank2K, kTrmm, @@ -165,6 +177,7 @@ enum class OpcodeClassID { kTensorOp, kWmmaTensorOp, kSparseTensorOp, + kBlockScaledOp, kInvalid }; @@ -188,6 +201,7 @@ enum class MathOperationID { /// Enumeration indicating what kind of GEMM operation to perform enum class GemmKind { kGemm, + kBlockScaledGemm, kSparse, kUniversal, kPlanarComplex, @@ -251,6 +265,20 @@ enum class EpilogueKind { kInvalid }; + +enum class RuntimeDatatype { + kStatic, + kE4M3, + kE5M2, + + kE3M2, + kE2M3, + kE2M1, + + kInvalid +}; + + enum class RasterOrder { kAlongN, kAlongM, diff --git a/tools/library/include/cutlass/library/util.h b/tools/library/include/cutlass/library/util.h index af82ffbc57..bf763e15c3 100644 --- a/tools/library/include/cutlass/library/util.h +++ b/tools/library/include/cutlass/library/util.h @@ -170,6 +170,15 @@ char const *to_string(ConvKind type, bool pretty = false); template <> ConvKind from_string(std::string const &str); + +/// Converts a RuntimeDatatype enumerant to a string +char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false); + +/// Convers a RuntimeDatatype enumerant from a string +template<> +cutlass::library::RuntimeDatatype from_string(std::string const &str); + + /// Converts a RasterOrder enumerant to a string char const *to_string(RasterOrder type, bool pretty = false); @@ -202,6 +211,8 @@ bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t /// Casts from a real value represented as a double to the destination type. Returns true if successful. bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); +NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/src/block_scaled_gemm_operation_3x.hpp b/tools/library/src/block_scaled_gemm_operation_3x.hpp new file mode 100644 index 0000000000..b95f72ec0f --- /dev/null +++ b/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using Sm100BlkScaledConfig = typename CollectiveMainloop::Sm100BlkScaledConfig; + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? 
ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; + + +private: + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockScaledGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = LayoutTypeID::kRowMajor; + description_.SFA.alignment = 128; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = LayoutTypeID::kRowMajor; + description_.SFB.alignment = 128; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFVecSize = SFVecSize; + + description_.SFD = make_TensorDescription(128); + description_.EpilogueSFVecSize = SFD_VectorSize; + + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the 
description of the GEMM operation + BlockScaledGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } else { + assert("Invalid input 
datatype."); + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + + operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b); + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockScaledGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return 
sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index a089cb5da2..91f579bfd5 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -162,6 +162,18 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { using CollectiveEpilogue = typename Operator::CollectiveEpilogue; using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + public: /// Constructor @@ -235,8 +247,42 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { arguments->batch_count); // update arguments + + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = 
mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments->A); operator_args.mainloop.ptr_B = static_cast(arguments->B); + } operator_args.epilogue.ptr_C = static_cast(arguments->C); operator_args.epilogue.ptr_D = static_cast(arguments->D); @@ -277,6 +323,22 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { } } + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; } diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index 82dc25d689..00ad2e0ec0 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -510,6 +510,15 @@ Status Handle::gemm_universal( int M, /// GEMM M dimension int N, /// GEMM N dimension int K, /// GEMM K dimension + + int cluster_m, /// cluster shape M dimension + int cluster_n, /// cluster shape N dimension + int cluster_k, /// cluster shape K dimension + int cluster_m_fallback, /// Fallback cluster shape M dimension + int cluster_n_fallback, /// Fallback cluster shape N dimension + int cluster_k_fallback, /// Fallback cluster shape K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation NumericTypeID element_scalar, /// Data type of alpha/beta scalars @@ -629,6 +638,8 @@ Status Handle::gemm_universal( GemmUniversalConfiguration configuration{ mode, {M, N, K}, + {cluster_m, cluster_n, cluster_k}, + {cluster_m_fallback, cluster_n_fallback, cluster_k_fallback}, batch_count, lda, ldb, @@ -647,6 +658,8 @@ Status Handle::gemm_universal( GemmUniversalArguments arguments{ {M, N, K}, + {cluster_m, cluster_n, cluster_k}, + {cluster_m_fallback, cluster_n_fallback, cluster_k_fallback}, batch_count, ptr_A, ptr_B, diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 8f4de51683..e8bd77397f 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -116,6 +116,27 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kFE5M2; }; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE3M2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M1; +}; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE8M0; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE4M3; +}; + + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kU16; }; @@ -161,6 +182,21 @@ template 
@@ -161,6 +182,21 @@ template <> struct NumericTypeMap {
 };
+
+
+template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float8_t> {
+  static NumericTypeID const kId = NumericTypeID::kF8;
+};
+
+template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float6_t> {
+  static NumericTypeID const kId = NumericTypeID::kF6;
+};
+
+template <> struct NumericTypeMap<cutlass::type_erased_dynamic_float4_t> {
+  static NumericTypeID const kId = NumericTypeID::kF4;
+};
+
+
 /////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename T> struct MathOperationMap {
@@ -300,6 +336,12 @@ template <> struct OpcodeClassMap<cutlass::arch::OpClassSparseTensorOp> {
   static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp;
 };
+
+template <> struct OpcodeClassMap<cutlass::arch::OpClassBlockScaledTensorOp> {
+  static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp;
+};
+
+
 template <> struct OpcodeClassMap<cutlass::arch::OpClassWmmaTensorOp> {
   static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp;
 };
diff --git a/tools/library/src/operation_table.cu b/tools/library/src/operation_table.cu
index 6719cd31ba..dd2b48c61e 100644
--- a/tools/library/src/operation_table.cu
+++ b/tools/library/src/operation_table.cu
@@ -48,6 +48,45 @@ void OperationTable::append(Manifest const &manifest) {
   // Insert operations into appropriate data structure
   for (auto const & operation : manifest) {
     OperationDescription const &desc = operation->description();
+
+    if (desc.kind == OperationKind::kBlockScaledGemm) {
+      BlockScaledGemmDescription const &gemm_desc = static_cast<BlockScaledGemmDescription const &>(desc);
+
+      BlockScaledGemmFunctionalKey functional_key(
+        gemm_desc.provider,
+        gemm_desc.gemm_kind,
+        gemm_desc.kind,
+        gemm_desc.tile_description.math_instruction.element_accumulator,
+        gemm_desc.element_epilogue,
+        gemm_desc.A.element,
+        gemm_desc.A.layout,
+        gemm_desc.SFA.element,
+        gemm_desc.B.element,
+        gemm_desc.B.layout,
+        gemm_desc.SFB.element,
+        gemm_desc.C.element,
+        gemm_desc.C.layout,
+        gemm_desc.D.element,
+        gemm_desc.D.layout,
+        gemm_desc.SFD.element,
+        gemm_desc.SFD.layout,
+        gemm_desc.SFVecSize,
+        gemm_desc.EpilogueSFVecSize);
+
+      Operation const *op = operation.get();
+
+      int cc = gemm_desc.tile_description.minimum_compute_capability;
+
+      int alignment = std::max(std::max(
+        gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment);
+
+      GemmPreferenceKey preference_key(cc, alignment);
+
+      block_scaled_gemm_operations[functional_key][preference_key].push_back(op);
+    }
+
+    // Insert all GEMM operations into the operation table
     if (desc.kind == OperationKind::kGemm) {
       GemmDescription const &gemm_desc = static_cast<GemmDescription const &>(desc);
diff --git a/tools/library/src/reference/block_scaled_gemm_fp4a_vs16.cu b/tools/library/src/reference/block_scaled_gemm_fp4a_vs16.cu
new file mode 100644
index 0000000000..7afb6db591
--- /dev/null
+++ b/tools/library/src/reference/block_scaled_gemm_fp4a_vs16.cu
@@ -0,0 +1,128 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3.
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "block_scaled_gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest) { + + ////////////////////////////////////////////////////////////////////////////////////////////////////////// + // SFVectorSize = 16 with MxF4NvF4 instructions + ////////////////////////////////////////////////////////////////////////////////////////////////////////// + // (float_e2m1_t * float_ue4m3_t) * (float_e2m1_t * float_ue4m3_t) + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 16 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/, + 16 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue4m3_t /*SFA*/, float_e2m1_t /*B*/, float_ue4m3_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + // (float_e2m1_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t) + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 16 /*SFVecSize*/ + >(manifest); + 
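+  // The instantiations above and below all use the same template-argument order,
+  // annotated inline on each call:
+  //   A, SFA, B, SFB, C, Compute, SFD, Accum, D, SFVecSize[, EpilogueSFVecSize].
+  // A void C means no C operand is read (beta is effectively unused), and a void
+  // SFD disables scale-factor generation for the output.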
+ make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 16 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/, + 16 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/, + 16 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm_tn< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 16 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/block_scaled_gemm_fp4a_vs32.cu b/tools/library/src/reference/block_scaled_gemm_fp4a_vs32.cu new file mode 100644 index 0000000000..b646d085de --- /dev/null +++ b/tools/library/src/reference/block_scaled_gemm_fp4a_vs32.cu @@ -0,0 +1,130 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "block_scaled_gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest) { + ////////////////////////////////////////////////////////////////////////////////////////////////////////// + // SFVectorSize = 32 with MxF4 instructions + ////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // (float_e2m1_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + // With SF generation reference + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/, + 16 /*EpiSFVecSize*/ + >(manifest); + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/, + 16 /*EpiSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpiSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpiSFVecSize*/ + 
>(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpiSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e2m1_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpiSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpiSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpiSFVecSize*/ + >(manifest); + +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/block_scaled_gemm_mixed8bitsa.cu b/tools/library/src/reference/block_scaled_gemm_mixed8bitsa.cu new file mode 100644 index 0000000000..d2fefa8ae0 --- /dev/null +++ b/tools/library/src/reference/block_scaled_gemm_mixed8bitsa.cu @@ -0,0 +1,354 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. 
+*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "block_scaled_gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest) { + + ////////////////////////////////////////////////////////////////////////////////////////////////////////// + // SFVectorSize = 32 with MxF8F6F4 instructions + ////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // (float_e2m3_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + // (float_e4m3_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + // (float_e2m3_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + 
make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + // (float_e2m1_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + // (float_e4m3_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + // (float_e4m3_t * float_ue8m0_t) * (float_e4m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t 
/*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e4m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e4m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + + + // (float_e3m2_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + + make_block_scaled_gemm< + float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t 
/*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e3m2_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + + // (float_e2m1_t * float_ue8m0_t) * (float_e2m3_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m1_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m3_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + + + // (float_e2m3_t * float_ue8m0_t) * (float_e2m1_t * float_ue8m0_t) + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + void /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, void /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/ + >(manifest); + + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e3m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + make_block_scaled_gemm< + float_e2m3_t /*A*/, float_ue8m0_t /*SFA*/, float_e2m1_t /*B*/, float_ue8m0_t /*SFB*/, + half_t /*C*/, float /*Compute*/, float_ue8m0_t /*SFD*/, float /*Accum*/, float_e5m2_t /*D*/, 32 /*SFVecSize*/, + 32 /*EpilogueSFVecSize*/ + >(manifest); + + +} 
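All of the reference operations registered above compute the same block-scaled contraction: each group of SFVecSize consecutive K elements of A (and of B) shares one scale factor, and the scaled products are accumulated in the Accum type. The following is a minimal host-side sketch of that math, using plain float stand-ins for the narrow element and scale-factor types and naive row-major/column-major indexing rather than the CuTe tensors and Sm100 scale-factor layouts the actual reference operations use:

  #include <cstddef>
  #include <vector>

  // D[m,n] = alpha * sum_k (A[m,k] * SFA[m, k/VS]) * (B[k,n] * SFB[n, k/VS]) + beta * C[m,n]
  void block_scaled_gemm_ref(int M, int N, int K, int SFVecSize,
                             float alpha, float beta,
                             std::vector<float> const& A,    // M x K, row-major
                             std::vector<float> const& SFA,  // M x ceil(K/SFVecSize), row-major
                             std::vector<float> const& B,    // K x N, column-major
                             std::vector<float> const& SFB,  // N x ceil(K/SFVecSize), row-major
                             std::vector<float> const& C,    // M x N, row-major
                             std::vector<float>& D) {        // M x N, row-major
    int num_sf_k = (K + SFVecSize - 1) / SFVecSize;
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k) {
          // One scale factor covers SFVecSize consecutive elements along K.
          float a = A[std::size_t(m) * K + k] * SFA[std::size_t(m) * num_sf_k + k / SFVecSize];
          float b = B[std::size_t(n) * K + k] * SFB[std::size_t(n) * num_sf_k + k / SFVecSize];
          acc += a * b;
        }
        D[std::size_t(m) * N + n] = alpha * acc + beta * C[std::size_t(m) * N + n];
      }
    }
  }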
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/tools/library/src/reference/block_scaled_gemm_reference_operation.h new file mode 100644 index 0000000000..e1b9fec815 --- /dev/null +++ b/tools/library/src/reference/block_scaled_gemm_reference_operation.h @@ -0,0 +1,459 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +namespace detail { +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ElementSFD_ = void, + typename LayoutSFD_ = LayoutC_, + int SFVecSize_ = 32, + int EpilogueSFVecSize_ = 0, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockScaledGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementSFD = ElementSFD_; + using LayoutSFD = LayoutSFD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + constexpr static int SFVecSize = SFVecSize_; + constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockScaledGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + description_.SFD = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFVecSize = SFVecSize; + description_.EpilogueSFVecSize = EpilogueSFVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockScaledGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + using Sm100BlockScaledConfig = cutlass::detail::Sm100BlockScaledConfig; + auto A = cute::make_tensor(detail::make_iterator(static_cast(args.A)), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), Sm100BlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + + auto B = cute::make_tensor(detail::make_iterator(static_cast(args.B)), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), 
Sm100BlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(detail::make_iterator(static_cast(args.C)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(detail::make_iterator(static_cast(nullptr)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(detail::make_iterator(static_cast(args.D)), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + if constexpr (not is_same_v) { + + using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig< + EpilogueSFVecSize + >; + + auto SfD = cute::make_tensor(detail::make_iterator(static_cast(args.SFD)), Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD), Int, cutlass::reference::host::SfStrategy::SfDGen> + epilogue_params{alpha, beta, C, D, SfD, *(static_cast(args.norm_constant))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + // W/O SF generation + auto SfD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L))); // not used. + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD)> + epilogue_params{alpha, beta, C, D, SfD}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + return Status::kSuccess; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm_tn(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm(Manifest &manifest) { + /// + /// A is Row , B is Col + /// + manifest.append(new BlockScaledGemmReferenceOperation< + 
Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + /// + /// A is Col , B is Row + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f4_f4_f32.cu b/tools/library/src/reference/gemm_f4_f4_f32.cu new file mode 100644 index 0000000000..7e1e67c46f --- /dev/null +++ b/tools/library/src/reference/gemm_f4_f4_f32.cu @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A/B : float_e2m1_t (not support float_e0m2_t to reduce ref kernel compile time) +// Acc: f32 +// C/D : some variance + +// 1. e2m1_e2m1_f32_f16_e4m3 +// 2. e2m1_e2m1_f32_f16_e5m2 +// 3. e2m1_e2m1_f32_f16_f16 +// 4. e2m1_e2m1_f32_f32_f32 + +void initialize_gemm_reference_operations_f4_f4_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e2m1_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f4_f6_f32.cu b/tools/library/src/reference/gemm_f4_f6_f32.cu new file mode 100644 index 0000000000..d5a48b7293 --- /dev/null +++ b/tools/library/src/reference/gemm_f4_f6_f32.cu @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A: float_e2m1_t +// B: float_e3m2_t +// Acc: f32 +// C/D : some variance + +// 1. e2m1_e3m2_f32_f16_e4m3 +// 2. e2m1_e3m2_f32_f16_e5m2 +// 3. e2m1_e3m2_f32_f16_f16 +// 4. e2m1_e3m2_f32_f32_f32 + +void initialize_gemm_reference_operations_f4_f6_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. 
+ make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e3m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f4_f8_f32.cu b/tools/library/src/reference/gemm_f4_f8_f32.cu new file mode 100644 index 0000000000..c4fec090e0 --- /dev/null +++ b/tools/library/src/reference/gemm_f4_f8_f32.cu @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A: float_e2m1_t +// B: float_e4m3_t +// Acc: f32 +// C/D : some variance + +// 1. e2m1_e4m3_f32_f16_e4m3 +// 2. e2m1_e4m3_f32_f16_e5m2 +// 3. e2m1_e4m3_f32_f16_f16 +// 4. e2m1_e4m3_f32_f32_f32 + +void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. 
+ make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e2m1_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f6_f4_f32.cu b/tools/library/src/reference/gemm_f6_f4_f32.cu new file mode 100644 index 0000000000..db03eaf89d --- /dev/null +++ b/tools/library/src/reference/gemm_f6_f4_f32.cu @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. 
+*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A: float_e3m2_t +// B: float_e2m1_t +// Acc: f32 +// C/D : some variance + +// 1. e3m2_e2m1_f32_f16_e4m3 +// 2. e3m2_e2m1_f32_f16_e5m2 +// 3. e3m2_e2m1_f32_f16_f16 +// 4. e3m2_e2m1_f32_f32_f32 + +void initialize_gemm_reference_operations_f6_f4_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e2m1_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f6_f6_f32.cu b/tools/library/src/reference/gemm_f6_f6_f32.cu new file mode 100644 index 0000000000..e090da9d07 --- /dev/null +++ b/tools/library/src/reference/gemm_f6_f6_f32.cu @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A/B : float_e3m2_t (not support float_e2m3_t to reduce ref kernel compile time) +// Acc: f32 +// C/D : some variance + +// 1. e3m2_e3m2_f32_f16_e4m3 +// 2. e3m2_e3m2_f32_f16_e5m2 +// 3. e3m2_e3m2_f32_f16_f16 +// 4. e3m2_e3m2_f32_f32_f32 + +void initialize_gemm_reference_operations_f6_f6_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e3m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f6_f8_f32.cu b/tools/library/src/reference/gemm_f6_f8_f32.cu new file mode 100644 index 0000000000..cedeae2c0d --- /dev/null +++ b/tools/library/src/reference/gemm_f6_f8_f32.cu @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A: float_e3m2_t +// B: float_e4m3_t +// Acc: f32 +// C/D : some variance + +// 1. e3m2_e4m3_f32_f16_e4m3 +// 2. e3m2_e4m3_f32_f16_e5m2 +// 3. e3m2_e4m3_f32_f16_f16 +// 4. e3m2_e4m3_f32_f32_f32 + +void initialize_gemm_reference_operations_f6_f8_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. 
+ make_gemm_real_canonical_layouts< + float_e3m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f8_f4_f32.cu b/tools/library/src/reference/gemm_f8_f4_f32.cu new file mode 100644 index 0000000000..064cd226bf --- /dev/null +++ b/tools/library/src/reference/gemm_f8_f4_f32.cu @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A: float_e4m3_t +// B: float_e2m1_t +// Acc: f32 +// C/D : some variance + +// 1. e4m3_e2m1_f32_f16_e4m3 +// 2. e4m3_e2m1_f32_f16_e5m2 +// 3. e4m3_e2m1_f32_f16_f16 +// 4. e4m3_e2m1_f32_f32_f32 + +void initialize_gemm_reference_operations_f8_f4_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. 
+ make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e2m1_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e2m1_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_f8_f6_f32.cu b/tools/library/src/reference/gemm_f8_f6_f32.cu new file mode 100644 index 0000000000..0a5bb9e95f --- /dev/null +++ b/tools/library/src/reference/gemm_f8_f6_f32.cu @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. 
+*/ + + + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// A: float_e4m3_t +// B: float_e3m2_t +// Acc: f32 +// C/D : some variance + +// 1. e4m3_e3m2_f32_f16_e4m3 +// 2. e4m3_e3m2_f32_f16_e5m2 +// 3. e4m3_e3m2_f32_f16_f16 +// 4. e4m3_e3m2_f32_f32_f32 + +void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest) { + + // 1. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + // 2. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + // 3. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e3m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + // 4. + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e3m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_u8_u8_s32.cu b/tools/library/src/reference/gemm_u8_u8_s32.cu index b3b1acf86c..1f2786eb3a 100644 --- a/tools/library/src/reference/gemm_u8_u8_s32.cu +++ b/tools/library/src/reference/gemm_u8_u8_s32.cu @@ -87,6 +87,17 @@ void initialize_gemm_reference_operations_u8_u8_s32(Manifest &manifest) { NumericConverterClamp // From Scalar to D >(manifest); + // 4. 
+ make_gemm_real_canonical_layouts< + uint8_t, // ElementA + uint8_t, // ElementB + int8_t, // ElementC + float, // ElementScalar / ElementCompute + int32_t, // ElementAccumulator + uint8_t, // ElementD + NumericConverterClamp // From Scalar to D + >(manifest); + } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index f5dcc67c4e..ad994acb36 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -52,6 +52,19 @@ void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest); void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest); void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest); void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest); + +void initialize_gemm_reference_operations_f4_f4_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f4_f6_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f4_f8_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f6_f4_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f6_f6_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f6_f8_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f8_f4_f32(Manifest &manifest); +void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest); +void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest); +void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest); +void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest); + void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest); void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest); void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest); @@ -89,6 +102,19 @@ void initialize_reference_operations(Manifest &manifest) { initialize_gemm_reference_operations_fp_mixed_input(manifest); initialize_gemm_reference_operations_int_mixed_input(manifest); + + initialize_gemm_reference_operations_f4_f4_f32(manifest); + initialize_gemm_reference_operations_f4_f6_f32(manifest); + initialize_gemm_reference_operations_f4_f8_f32(manifest); + initialize_gemm_reference_operations_f6_f4_f32(manifest); + initialize_gemm_reference_operations_f6_f6_f32(manifest); + initialize_gemm_reference_operations_f6_f8_f32(manifest); + initialize_gemm_reference_operations_f8_f4_f32(manifest); + initialize_gemm_reference_operations_f8_f6_f32(manifest); + initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest); + initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest); + initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest); + } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/sparse_gemm_operation_3x.hpp b/tools/library/src/sparse_gemm_operation_3x.hpp index 2fc51ff66c..c77c236364 100644 --- a/tools/library/src/sparse_gemm_operation_3x.hpp +++ b/tools/library/src/sparse_gemm_operation_3x.hpp @@ -211,6 +211,10 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { } } + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } 
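// A minimal host-side sketch of what one of the reference instantiations registered above
// computes, using the f4_f8 combination as an example: D = alpha * (A x B) + beta * C with
// FP32 accumulation. Element types mirror the template arguments passed to
// make_gemm_real_canonical_layouts (A = float_e2m1_t, B = float_e4m3_t, C/D = half_t).
// The include set and the float conversions below are assumptions based on the CUTLASS
// numeric types, not code taken from this patch; it is an illustration, not the library's
// reference path.
#include "cutlass/numeric_types.h"
#include <vector>

inline void reference_gemm_f4_f8_sketch(
    int M, int N, int K, float alpha, float beta,
    std::vector<cutlass::float_e2m1_t> const &A,   // M x K, row-major
    std::vector<cutlass::float_e4m3_t> const &B,   // K x N, row-major
    std::vector<cutlass::half_t> const &C,         // M x N, row-major
    std::vector<cutlass::half_t> &D) {             // M x N, row-major
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;                                      // ElementAccumulator = float
      for (int k = 0; k < K; ++k) {
        acc += float(A[m * K + k]) * float(B[k * N + n]);   // widen FP4/FP8 inputs to FP32
      }
      D[m * N + n] = cutlass::half_t(alpha * acc + beta * float(C[m * N + n]));
    }
  }
}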
+ return status; } diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index f841781ab3..5b39e74954 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -334,6 +334,7 @@ static struct { OperationKind_enumerants[] = { {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, {"gemm", "Gemm", OperationKind::kGemm}, + {"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm}, {"rank_k", "RankK", OperationKind::kRankK}, {"rank_2k", "Rank2K", OperationKind::kRank2K}, {"trmm", "Trmm", OperationKind::kTrmm}, @@ -422,6 +423,53 @@ Status from_string(std::string const &str) { /////////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const *text; + char const *pretty; + RuntimeDatatype enumerant; +} +RuntimeDatatype_enumerants[] = { + {"e4m3", "", RuntimeDatatype::kE4M3}, + {"e5m2", "", RuntimeDatatype::kE5M2}, + {"e3m2", "", RuntimeDatatype::kE3M2}, + {"e2m3", "", RuntimeDatatype::kE2M3}, + {"e2m1", "", RuntimeDatatype::kE2M1} +}; + +/// Converts a RuntimeDatatype enumerant to a string +char const *to_string(RuntimeDatatype type, bool pretty) { + + for (auto const & possible : RuntimeDatatype_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? "Invalid" : "invalid"; +} + + +/// Converts a RuntimeDatatype enumerant from a string +template <> +RuntimeDatatype from_string(std::string const &str) { + + for (auto const & possible : RuntimeDatatype_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return RuntimeDatatype::kInvalid; +} + + /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { @@ -447,6 +495,16 @@ NumericTypeID_enumerants[] = { {"s64", "S64", NumericTypeID::kS64}, {"fe4m3", "FE4M3", NumericTypeID::kFE4M3}, {"fe5m2", "FE5M2", NumericTypeID::kFE5M2}, + + {"f8", "F8", NumericTypeID::kF8}, + {"f6", "F6", NumericTypeID::kF6}, + {"f4", "F4", NumericTypeID::kF4}, + {"fe2m3", "FE2M3", NumericTypeID::kFE2M3}, + {"fe3m2", "FE3M2", NumericTypeID::kFE3M2}, + {"fe2m1", "FE2M1", NumericTypeID::kFE2M1}, + {"fue8m0", "FUE8M0", NumericTypeID::kFUE8M0}, + {"fue4m3", "FUE4M3", NumericTypeID::kFUE4M3}, + {"f16", "F16", NumericTypeID::kF16}, {"bf16", "BF16", NumericTypeID::kBF16}, {"f32", "F32", NumericTypeID::kF32}, @@ -510,6 +568,16 @@ int sizeof_bits(NumericTypeID type) { switch (type) { case NumericTypeID::kFE4M3: return 8; case NumericTypeID::kFE5M2: return 8; + + case NumericTypeID::kF8: return 8; + case NumericTypeID::kF6: return 6; + case NumericTypeID::kF4: return 4; + case NumericTypeID::kFE2M3: return 6; + case NumericTypeID::kFE3M2: return 6; + case NumericTypeID::kFE2M1: return 4; + case NumericTypeID::kFUE8M0: return 8; + case NumericTypeID::kFUE4M3: return 8; + case NumericTypeID::kF16: return 16; case NumericTypeID::kBF16: return 16; case NumericTypeID::kTF32: return 32; @@ -589,6 +657,16 @@ bool is_signed_type(NumericTypeID type) { switch (type) { case NumericTypeID::kFE4M3: return true; case NumericTypeID::kFE5M2: return true; + + case NumericTypeID::kF8: return true; + case NumericTypeID::kF6: return true; + case NumericTypeID::kF4: return true; + case NumericTypeID::kFE2M3: return true; + case NumericTypeID::kFE3M2: return true; + case NumericTypeID::kFE2M1: return true; + case NumericTypeID::kFUE8M0: return false; + case 
NumericTypeID::kFUE4M3: return false; + case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; @@ -620,6 +698,16 @@ bool is_float_type(NumericTypeID type) { switch (type) { case NumericTypeID::kFE4M3: return true; case NumericTypeID::kFE5M2: return true; + + case NumericTypeID::kF8: return true; + case NumericTypeID::kF6: return true; + case NumericTypeID::kF4: return true; + case NumericTypeID::kFE2M3: return true; + case NumericTypeID::kFE3M2: return true; + case NumericTypeID::kFE2M1: return true; + case NumericTypeID::kFUE8M0: return true; + case NumericTypeID::kFUE4M3: return true; + case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; @@ -1168,6 +1256,43 @@ bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string c *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; + + case NumericTypeID::kFE2M3: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kFE3M2: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kFE2M1: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kFUE8M0: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kFUE4M3: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kF16: { float tmp; @@ -1317,6 +1442,38 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { ss << tmp; } break; + + case NumericTypeID::kFE2M3: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; + case NumericTypeID::kFE3M2: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; + case NumericTypeID::kFE2M1: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; + case NumericTypeID::kFUE8M0: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; + case NumericTypeID::kFUE4M3: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; + case NumericTypeID::kF16: { float tmp = *reinterpret_cast(bytes.data()); @@ -1469,6 +1626,33 @@ bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t sr *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; + + case NumericTypeID::kFE2M3: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE3M2: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE2M1: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFUE8M0: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFUE4M3: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); @@ -1579,6 +1763,33 @@ bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; + + case NumericTypeID::kFE2M3: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE3M2: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE2M1: + { + 
*reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
+    }
+    break;
+  case NumericTypeID::kFUE8M0:
+    {
+      *reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
+    }
+    break;
+  case NumericTypeID::kFUE4M3:
+    {
+      *reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
+    }
+    break;
+
   case NumericTypeID::kF16:
   {
     *reinterpret_cast(bytes.data()) = static_cast(float(src));
@@ -1690,6 +1901,33 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr
     *reinterpret_cast(bytes.data()) = static_cast(float(src));
   }
   break;
+
+  case NumericTypeID::kFE2M3:
+    {
+      *reinterpret_cast<float_e2m3_t *>(bytes.data()) = static_cast<float_e2m3_t>(float(src));
+    }
+    break;
+  case NumericTypeID::kFE3M2:
+    {
+      *reinterpret_cast<float_e3m2_t *>(bytes.data()) = static_cast<float_e3m2_t>(float(src));
+    }
+    break;
+  case NumericTypeID::kFE2M1:
+    {
+      *reinterpret_cast<float_e2m1_t *>(bytes.data()) = static_cast<float_e2m1_t>(float(src));
+    }
+    break;
+  case NumericTypeID::kFUE8M0:
+    {
+      *reinterpret_cast<float_ue8m0_t *>(bytes.data()) = static_cast<float_ue8m0_t>(float(src));
+    }
+    break;
+  case NumericTypeID::kFUE4M3:
+    {
+      *reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
+    }
+    break;
+
   case NumericTypeID::kF16:
   {
     *reinterpret_cast(bytes.data()) = static_cast(float(src));
@@ -1751,6 +1989,35 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr
   return true;
 }
 
+
+NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type) {
+  NumericTypeID element{};
+  switch (type) {
+    case RuntimeDatatype::kE4M3:
+      element = NumericTypeID::kFE4M3;
+      break;
+    case RuntimeDatatype::kE5M2:
+      element = NumericTypeID::kFE5M2;
+      break;
+
+    case RuntimeDatatype::kE2M3:
+      element = NumericTypeID::kFE2M3;
+      break;
+    case RuntimeDatatype::kE3M2:
+      element = NumericTypeID::kFE3M2;
+      break;
+    case RuntimeDatatype::kE2M1:
+      element = NumericTypeID::kFE2M1;
+      break;
+
+    default:
+      assert(false && "illegal runtime datatype!");
+      break;
+  }
+  return element;
+}
+
+
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace library
diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt
index 7038289da7..0e257413c4 100644
--- a/tools/profiler/CMakeLists.txt
+++ b/tools/profiler/CMakeLists.txt
@@ -46,6 +46,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
   src/problem_space.cpp
   src/operation_profiler.cu
   src/gemm_operation_profiler.cu
+  src/block_scaled_gemm_operation_profiler.cu
   src/rank_k_operation_profiler.cu
   src/rank_2k_operation_profiler.cu
   src/trmm_operation_profiler.cu
@@ -101,6 +102,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A
   set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
 else()
   set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
+  set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true)
 endif()
 
 set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true)
diff --git a/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h
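The RuntimeDatatype helpers added to util.cu above are what the profiler uses to turn a command-line datatype string into a concrete NumericTypeID. A minimal sketch of how they compose is below; it assumes the new helpers are declared alongside the existing enumerant utilities in cutlass/library/util.h, which is an assumption rather than something shown in this patch.

#include "cutlass/library/util.h"
#include <cassert>
#include <string>

int element_bits_for(std::string const &cli_arg) {             // e.g. cli_arg == "e2m1"
  using namespace cutlass::library;
  RuntimeDatatype rt = from_string<RuntimeDatatype>(cli_arg);  // string -> RuntimeDatatype
  assert(rt != RuntimeDatatype::kInvalid && "unrecognized runtime datatype string");
  NumericTypeID id = dynamic_datatype_to_id(rt);               // kE2M1 -> NumericTypeID::kFE2M1
  return sizeof_bits(id);                                      // 4 bits per element for kFE2M1
}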
new file mode 100644 index 0000000000..f0e73cc584 --- /dev/null +++ b/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ **************************************************************************************************/
+/* \file
+   \brief Block-scaled GEMM operation profiler
+*/
+
+
+
+#pragma once
+
+#include <vector>
+#include <string>
+#include <memory>
+#include <algorithm>
+#include <unordered_map>
+
+// CUTLASS Library includes
+#include "cutlass/library/library.h"
+#include "cutlass/library/util.h"
+#include "cutlass/library/manifest.h"
+
+// Profiler includes
+#include "options.h"
+#include "device_context.h"
+#include "operation_profiler.h"
+#include "performance_result.h"
+#include "problem_space.h"
+#include "reduction_operation_profiler.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace profiler {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Operation profiler for block-scaled GEMM operations
+class BlockScaledGemmOperationProfiler : public OperationProfiler {
+public:
+
+  /// Problem structure obtained from problem space
+  struct GemmProblem {
+
+    cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm};
+
+    int64_t m{16};
+    int64_t n{16};
+    int64_t k{16};
+
+
+    int cluster_m{1};
+    int cluster_n{1};
+    int cluster_k{1};
+    int cluster_m_fallback{1};
+    int cluster_n_fallback{1};
+    int cluster_k_fallback{1};
+
+
+    int64_t lda{0};
+    int64_t ldb{0};
+    int64_t ldc{0};
+    std::vector<uint8_t> alpha;
+    std::vector<uint8_t> beta;
+
+    cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone};
+    int split_k_slices{1};
+    int batch_count{1};
+
+    cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
+    int swizzle_size{1};
+
+
+    cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
+    cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
+
+
+    // gemm with parallel interleaved reduction
+    // gemm epilogue (alpha, beta) = (1.0, 0.0)
+    // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
+    std::vector<uint8_t> alpha_one;
+    std::vector<uint8_t> beta_zero;
+
+    bool use_pdl{false};
+    //
+    // Methods
+    //
+
+    /// Parses the problem
+    Status parse(
+      library::BlockScaledGemmDescription const &operation_desc,
+      ProblemSpace const &problem_space,
+      ProblemSpace::Problem const &problem);
+
+    /// Total number of bytes loaded
+    int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const;
+
+    /// Total number of flops computed
+    int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const;
+
+    /// Initializes a performance result
+    void initialize_result(
+      PerformanceResult &result,
+      library::BlockScaledGemmDescription const &operation_desc,
+      ProblemSpace const &problem_space);
+  };
+
+  /// Workspace used
+  struct GemmWorkspace {
+
+    DeviceAllocation *A{nullptr};
+    DeviceAllocation *SFA{nullptr};
+    DeviceAllocation *B{nullptr};
+    DeviceAllocation *SFB{nullptr};
+    DeviceAllocation *C{nullptr};
+    DeviceAllocation *Computed{nullptr};
+    DeviceAllocation *Reference{nullptr};
+    DeviceAllocation *Computed_SFD{nullptr};
+    DeviceAllocation *Reference_SFD{nullptr};
+    DeviceAllocation *Norm_constant{nullptr};
+
+    /// Number of copies of the problem workspace which are visited sequentially during
+    /// profiling to avoid camping in the last level cache.
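// A rough sketch of why the workspace keeps separate SFA/SFB (and SFD) allocations: block-scaled
// kernels carry one narrow scale factor (ue8m0 or ue4m3) per group of SFVecSize elements along K,
// with the library's reference paths instantiated for vector sizes 16 and 32 (the *_vs16/_vs32
// initializers registered earlier). The helper below only illustrates the element count; the
// actual scale-factor layout is defined by the CUTLASS block-scaled GEMM kernels, and SFVecSize
// here is an assumed parameter name, not one used by this header.
inline int64_t scale_factor_count_sketch(int64_t rows, int64_t k, int64_t sf_vec_size) {
  return rows * ((k + sf_vec_size - 1) / sf_vec_size);   // one scale per SFVecSize-wide group of K
}
// Example: m = n = 4096, k = 8192, sf_vec_size = 16 gives 4096 * 512 = 2,097,152 scales per operand.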
+ int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockScaledGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockScaledGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockScaledGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void 
*device_workspace);
+
+  /// Initialize reduction problem dimensions and library::Operation
+  bool initialize_reduction_configuration_(
+    library::Operation const *operation,
+    ProblemSpace::Problem const &problem);
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace profiler
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h
index b87b73f8a0..18b18cbc88 100644
--- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h
+++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h
@@ -73,6 +73,15 @@ class GemmOperationProfiler : public OperationProfiler {
     int64_t n{16};
     int64_t k{16};
+
+    int cluster_m{1};
+    int cluster_n{1};
+    int cluster_k{1};
+    int cluster_m_fallback{1};
+    int cluster_n_fallback{1};
+    int cluster_k_fallback{1};
+
+
     int64_t lda{0};
     int64_t ldb{0};
     int64_t ldc{0};
@@ -86,6 +95,11 @@ class GemmOperationProfiler : public OperationProfiler {
     cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic};
     int swizzle_size{1};
+
+    cutlass::library::RuntimeDatatype runtime_input_datatype_a{};
+    cutlass::library::RuntimeDatatype runtime_input_datatype_b{};
+
+
     // gemm with parallel interleaved reduction
     // gemm epilogue (alpha, beta) = (1.0, 0.0)
     // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
diff --git a/tools/profiler/include/cutlass/profiler/problem_space.h b/tools/profiler/include/cutlass/profiler/problem_space.h
index 8a5f001d0b..03903e3eec 100644
--- a/tools/profiler/include/cutlass/profiler/problem_space.h
+++ b/tools/profiler/include/cutlass/profiler/problem_space.h
@@ -942,6 +942,18 @@ bool arg_as_IteratorAlgorithmID(
   ProblemSpace const &problem_space,
   ProblemSpace::Problem const &problem);
+
+/// Lexically casts an argument to a RuntimeDatatype if it is defined. Returns true if not null.
+bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr);
+
+/// Lexically casts an argument to a RuntimeDatatype if it is defined. Returns true if not null.
+bool arg_as_RuntimeDatatype(
+  library::RuntimeDatatype &runtime_datatype,
+  char const *name,
+  ProblemSpace const &problem_space,
+  ProblemSpace::Problem const &problem);
+
+
 /// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
 bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr);
diff --git a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu
new file mode 100644
index 0000000000..81ea052227
--- /dev/null
+++ b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu
@@ -0,0 +1,1371 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + + + +#include +#include +#include +#include +#include + +#include "cutlass/core_io.h" + +#include "cutlass/profiler/cublas_helpers.h" +#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h" +#include "cutlass/profiler/gpu_timer.h" +#include "cutlass/library/singleton.h" +#include "cutlass/library/library.h" +#include "cutlass/library/handle.h" + +#include "cutlass/util/reference/host/gett.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Ctor +BlockScaledGemmOperationProfiler::BlockScaledGemmOperationProfiler(Options const &options): + OperationProfiler( + options, + library::OperationKind::kBlockScaledGemm, + { + {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (universal, gemm, planar_complex, planar_complex_array)"}, + {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, + {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, + {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, + {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, + {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, + {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, + {ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D output"}, + {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, + {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, + // TODO: Bring these back once SM100 future audits are complete + {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"}, + {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, + {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, + {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime 
datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, + {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, + {ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"}, + {ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"}, + {ArgumentTypeID::kEnumerated, {"use_pdl", "use_pdl"}, "Use PDL (true, false)"}, + }, + { library::Provider::kCUBLAS} + ) { + + description_ = " General matrix-matrix product. D = alpha * A*B + beta * C"; +} + +/// Destructor +BlockScaledGemmOperationProfiler::~BlockScaledGemmOperationProfiler() { + +} + +/// Prints usage statement for the math function +void BlockScaledGemmOperationProfiler::print_usage(std::ostream &out) const { + out << "Block Scaled GEMM" << "\n\n"; + + OperationProfiler::print_usage(out); +} + +/// Prints examples +void BlockScaledGemmOperationProfiler::print_examples(std::ostream &out) const { + + out << "\nExamples:\n\n" + << "Profile a particular problem size:\n" + << " $ cutlass_profiler --operation=block_scaled_gemm --m=1024 --n=1024 --k=128\n\n" + + << "Schmoo over problem size and beta:\n" + << " $ cutlass_profiler --operation=block_scaled_gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" + + // TODO: Bring these back once SM100 future audits are complete +#if 0 + << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" + << " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n" + + << "Profile a particular problem size with split K and parallel reduction:\n" + << " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n" +#endif + + << "Using various input value distribution:\n" + << " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n" + << " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n" + << " $ cutlass_profiler --operation=Gemm --dist=sequential,start:0,delta:1\n\n" + + << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" + << " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" + + << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" + << " $ cutlass_profiler --operation=Gemm \\ \n" + << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" + << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" + << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" + << " --beta=0,1,2 --profiling-iterations=1 \\ \n" + << " --providers=cutlass --output=functional-test.csv\n\n"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if 0 +// used this for debugging +static std::string byte_string(std::vector const &bytes) { + std::stringstream ss; + + ss << "0x"; + + for (size_t idx = bytes.size(); idx > 0; --idx) { + ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); + } + + return ss.str(); +} +#endif + +Status BlockScaledGemmOperationProfiler::GemmProblem::parse( + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + this->mode = 
library::GemmUniversalMode::kGemm; + + if (!arg_as_int(this->m, "m", problem_space, problem)) { + // default value + this->m = 1024; + } + + if (!arg_as_int(this->n, "n", problem_space, problem)) { + // default value + this->n = 1024; + } + + if (!arg_as_int(this->k, "k", problem_space, problem)) { + // default value + this->k = 1024; + } + + + if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { + // default value + this->cluster_m = 1; + } + + if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { + // default value + this->cluster_n = 1; + } + + if (!arg_as_int(this->cluster_k, "cluster_k", problem_space, problem)) { + // default value + this->cluster_k = 1; + } + + if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { + // default value + this->cluster_m_fallback = 0; + } + + if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { + // default value + this->cluster_n_fallback = 0; + } + + if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { + // default value + this->cluster_k_fallback = 0; + } + + + if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) { + // default value + this->split_k_mode = library::SplitKMode::kSerial; + } + + this->mode = library::GemmUniversalMode::kGemm; + if (this->split_k_mode == library::SplitKMode::kParallel) { + this->mode = library::GemmUniversalMode::kGemmSplitKParallel; + } + + if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { + // default value + this->split_k_slices = 1; + } + + // TODO: Bring these back once SM100 future audits are complete + if (this->split_k_mode != library::SplitKMode::kSerial) { + std::cout<<"SplitK/StreamK feature is not supported yet!"; + return Status::kErrorInvalidProblem; + } + + if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) { + // default value + this->use_pdl = false; + } + + + if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) { + // default value + this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic; + } + + if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) { + // default value + this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic; + } + + + if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { + // default value + this->batch_count = 1; + } else if (this->batch_count > 1) { + this->mode = library::GemmUniversalMode::kBatched; + } + + if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) { + // default value + this->swizzle_size = 1; + } + + if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) { + // default value + this->raster_order = library::RasterOrder::kHeuristic; + } + + if (this->split_k_slices > 1 && this->batch_count > 1) { + // At least one of these must be one + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if 
(!tensor_description_satisfies(operation_desc.D, "D", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!arg_as_scalar( + this->alpha, + operation_desc.element_epilogue, + "alpha", + problem_space, + problem)) { + + if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { + return Status::kErrorInternal; + } + } + + if (!arg_as_scalar( + this->beta, + operation_desc.element_epilogue, + "beta", + problem_space, + problem)) { + + if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { + return Status::kErrorInternal; + } + } + + this->lda = DeviceAllocation::get_packed_layout( + operation_desc.A.layout, {int(this->m), int(this->k)}).front(); + + this->ldb = DeviceAllocation::get_packed_layout( + operation_desc.B.layout, {int(this->k), int(this->n)}).front(); + + this->ldc = DeviceAllocation::get_packed_layout( + operation_desc.C.layout, {int(this->m), int(this->n)}).front(); + + return Status::kSuccess; +} + +/// Total number of bytes loaded +int64_t BlockScaledGemmOperationProfiler::GemmProblem::bytes(library::BlockScaledGemmDescription const &operation_desc) const { + // Input bytes read and Output bytes written for the gemm problem + int64_t bytes = + int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * k + + int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + + int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; + + // Set is_beta_zero true if beta is zero + bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); + + // Output bytes read for the gemm problem for non-zero beta values + if (!is_beta_zero) { + bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; + } + + bytes *= batch_count; + + return bytes; +} + +/// Total number of flops computed +int64_t BlockScaledGemmOperationProfiler::GemmProblem::flops(library::BlockScaledGemmDescription const &operation_desc) const { + int64_t flops_ = (int64_t(m) * n * k + m * n) * 2 * batch_count; + + // complex-valued support + switch (operation_desc.tile_description.math_instruction.math_operation) { + case library::MathOperationID::kMultiplyAddComplex: + flops_ *= 4; + break; + + case library::MathOperationID::kMultiplyAddComplexFastF32: + flops_ *= 4; + break; + + case library::MathOperationID::kMultiplyAddGaussianComplex: + flops_ *= 3; + break; + + default: break; + } + + return flops_; +} + + +/// Initializes a performance result +void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space) { + + result.arguments.resize(problem_space.rank()); + + set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); + + set_argument(result, "A", problem_space, + std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); + + set_argument(result, "B", problem_space, + std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); + + set_argument(result, "C", problem_space, + std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); + + set_argument(result, "D", problem_space, + std::string(library::to_string(operation_desc.D.element)) + ":" + library::to_string(operation_desc.D.layout)); + + set_argument(result, "m", problem_space, m); + set_argument(result, 
"n", problem_space, n); + set_argument(result, "k", problem_space, k); + + + set_argument(result, "cluster_m", problem_space, cluster_m); + set_argument(result, "cluster_n", problem_space, cluster_n); + set_argument(result, "cluster_k", problem_space, cluster_k); + set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback); + set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback); + set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback); + + + // TODO: Bring these back once SM100 future audits are complete + set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); + set_argument(result, "split_k_slices", problem_space, split_k_slices); + set_argument(result, "batch_count", problem_space, batch_count); + set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); + set_argument(result, "swizzle_size", problem_space, swizzle_size); + set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl)); + + + set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a)); + set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b)); + + + set_argument(result, "alpha", problem_space, + library::lexical_cast(alpha, operation_desc.element_epilogue)); + + set_argument(result, "beta", problem_space, + library::lexical_cast(beta, operation_desc.element_epilogue)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Extracts the problem dimensions +Status BlockScaledGemmOperationProfiler::initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + library::BlockScaledGemmDescription const &operation_desc = + static_cast(operation->description()); + + if (operation_desc.gemm_kind != library::GemmKind::kUniversal) { + return Status::kErrorInvalidProblem; + } + + Status status = problem_.parse(operation_desc, problem_space, problem); + + if (status != Status::kSuccess) { + return status; + } + + gemm_workspace_.configuration.mode = problem_.mode; + gemm_workspace_.configuration.problem_size.m() = int(problem_.m); + gemm_workspace_.configuration.problem_size.n() = int(problem_.n); + gemm_workspace_.configuration.problem_size.k() = int(problem_.k); + + gemm_workspace_.configuration.cluster_shape.m() = int(problem_.cluster_m); + gemm_workspace_.configuration.cluster_shape.n() = int(problem_.cluster_n); + gemm_workspace_.configuration.cluster_shape.k() = int(problem_.cluster_k); + gemm_workspace_.configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback); + gemm_workspace_.configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback); + gemm_workspace_.configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback); + + gemm_workspace_.configuration.lda = problem_.lda; + gemm_workspace_.configuration.ldb = problem_.ldb; + gemm_workspace_.configuration.ldc = problem_.ldc; + gemm_workspace_.configuration.ldd = problem_.ldc; + + if (problem_.mode == library::GemmUniversalMode::kBatched) { + gemm_workspace_.configuration.batch_count = problem_.batch_count; + } + else { + gemm_workspace_.configuration.batch_count = problem_.split_k_slices; + } + + gemm_workspace_.arguments.problem_size.m() = int(problem_.m); + 
gemm_workspace_.arguments.problem_size.n() = int(problem_.n); + gemm_workspace_.arguments.problem_size.k() = int(problem_.k); + gemm_workspace_.arguments.batch_count = problem_.batch_count; + + gemm_workspace_.arguments.A = nullptr; + gemm_workspace_.arguments.B = nullptr; + gemm_workspace_.arguments.C = nullptr; + gemm_workspace_.arguments.D = nullptr; + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size; + gemm_workspace_.arguments.raster_order = problem_.raster_order; + gemm_workspace_.arguments.norm_constant = 0; + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + gemm_workspace_.arguments.split_k_slices = problem_.split_k_slices; + + + gemm_workspace_.arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a; + gemm_workspace_.arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b; + + + gemm_workspace_.arguments.use_pdl = problem_.use_pdl; + + // initialize reduction operation for parallel splitKMode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + if (!initialize_reduction_configuration_(operation, problem)) { + return Status::kErrorInternal; + } + } + + initialize_result_(this->model_result_, options, operation_desc, problem_space); + + return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); +} + +/// Initializes the performance result +void BlockScaledGemmOperationProfiler::initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space) { + + result.provider = library::Provider::kCUTLASS; + result.disposition = Disposition::kNotRun; + result.status = Status::kSuccess; + result.operation_name = operation_desc.name; + + problem_.initialize_result(result, operation_desc, problem_space); + + OperationProfiler::initialize_result_(result, operation_desc, problem_space); + + result.bytes = problem_.bytes(operation_desc); + result.flops = problem_.flops(operation_desc); + result.runtime = 0; + +} + +/// Initialize reduction problem dimensions and library::Operation +bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem) { + + // TODO: Bring these back once SM100 future audits are complete +#if 1 + library::BlockScaledGemmDescription const &gemm_desc = + static_cast(operation->description()); + + if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) { + return false; + } + + if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) { + return false; + } + + /// initialize library::ReductionConfiguration + gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); + gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); + gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); + gemm_workspace_.reduction_configuration.ldw = problem_.ldc; + 
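For context on the reduction configured here (a path that parse() currently rejects for this profiler): parallel split-K has the GEMM write split_k_slices partial products of size m x n into the device workspace, and the reduction then folds them into D using the user's alpha and beta. The host-side schematic below only illustrates that computation; it is not the library's reduction kernel, and the element types are placeholders.

#include <cstdint>

// Schematic of the parallel split-K reduction: D = alpha * sum_p(workspace_p) + beta * C.
// partition_stride is m * n elements, matching reduction_configuration above.
template <typename AccumT, typename ElementC>
void reduce_split_k(ElementC *D, AccumT const *workspace, ElementC const *C,
                    int64_t m, int64_t n, int partitions, AccumT alpha, AccumT beta) {
  int64_t partition_stride = m * n;
  for (int64_t idx = 0; idx < m * n; ++idx) {
    AccumT sum = AccumT(0);
    for (int p = 0; p < partitions; ++p) {
      sum += workspace[p * partition_stride + idx];
    }
    D[idx] = ElementC(alpha * sum + beta * AccumT(C[idx]));
  }
}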
gemm_workspace_.reduction_configuration.lds = problem_.ldc; + gemm_workspace_.reduction_configuration.ldd = problem_.ldc; + + // find reduction operation + library::ReductionFunctionalKey reduction_key( + library::Provider::kCUTLASS, + gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace + gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator + gemm_desc.D.element, // element output + gemm_desc.element_epilogue // element compute + ); + + auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key); + + if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) { + return false; + } + + // initialize reduction operation required for parallel split-k operator + reduction_op_ = reduction_it->second; + + // reduction operation found and initialized + return true; +#endif + return false; +} + +/// Initializes workspace +Status BlockScaledGemmOperationProfiler::initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + if (options.device.devices.size() != 1) { + throw std::runtime_error("This operation profiler only supports a single " + "device."); + } + + cudaError_t result; + result = cudaSetDevice(options.device.device_id(0)); + if (result != cudaSuccess) { + throw std::runtime_error("cudaSetDevice() failed."); + } + + library::Operation const* underlying_operation = operation; + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { + return Status::kErrorNotSupported; + } + } + + library::BlockScaledGemmDescription const &operation_desc = + static_cast(operation->description()); + + // Compute the number of copies of the problem to avoid L2 camping. + if (!options.profiling.workspace_count) { + int64_t bytes = problem_.bytes(operation_desc); + if (bytes < 3 * int64_t(options.device.properties[0].l2CacheSize)) { + gemm_workspace_.problem_count = + 1 + int((3 * int64_t(options.device.properties[0].l2CacheSize)) / bytes); + } + else { + gemm_workspace_.problem_count = 1; + } + } + else { + gemm_workspace_.problem_count = options.profiling.workspace_count; + } + + bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; + if (allocate_device_tensors) { + int seed_shift = 0; + gemm_workspace_.A = device_context.allocate_and_initialize_tensor( + options, + "A", + operation_desc.A.element, + operation_desc.A.layout, + {int(problem_.m), int(problem_.k)}, + {int(problem_.lda)}, + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++, + 0 // device_index + ); + + int sfa_m = round_up(int(problem_.m), 128); + int sfb_n = round_up(int(problem_.n), 128); + int sfa_sfb_k = round_up(ceil_div(int(problem_.k), operation_desc.SFVecSize), 4); + + int sfd_m = operation_desc.SFD.layout == cutlass::library::LayoutTypeID::kRowMajor ? + sfa_m : round_up(ceil_div(int(problem_.m), operation_desc.EpilogueSFVecSize), 4); + int sfd_n = operation_desc.SFD.layout == cutlass::library::LayoutTypeID::kRowMajor ? 
+ round_up(ceil_div(int(problem_.n), operation_desc.EpilogueSFVecSize), 4) : sfb_n; + + + gemm_workspace_.SFA = device_context.allocate_and_initialize_tensor( + options, + "SFA", + operation_desc.SFA.element, + operation_desc.SFA.layout, + {sfa_m, sfa_sfb_k}, + {sfa_sfb_k}, + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++, + 0 // device_index + ); + + gemm_workspace_.SFB = device_context.allocate_and_initialize_tensor( + options, + "SFB", + operation_desc.SFB.element, + operation_desc.SFB.layout, + {sfb_n, sfa_sfb_k}, + {sfa_sfb_k}, + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++, + 0 // device_index + ); + + gemm_workspace_.B = device_context.allocate_and_initialize_tensor( + options, + "B", + operation_desc.B.element, + operation_desc.B.layout, + {int(problem_.k), int(problem_.n)}, + {int(problem_.ldb)}, + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++, + 0 // device_index + ); + + gemm_workspace_.C = device_context.allocate_and_initialize_tensor( + options, + "C", + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++, + 0 // device_index + ); + + gemm_workspace_.Computed = device_context.allocate_tensor( + options, + "D", + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_.problem_count, + 0 // device_index + ); + + gemm_workspace_.Reference = device_context.allocate_tensor( + options, + "Reference", + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_.problem_count, + 0 // device_index + ); + + gemm_workspace_.Computed_SFD = device_context.allocate_tensor( + options, + "SFD", + operation_desc.SFD.element, + operation_desc.SFD.layout, + {sfd_m, sfd_n}, + {sfd_n}, + problem_.batch_count * gemm_workspace_.problem_count, + 0 // device_index + ); + + gemm_workspace_.Reference_SFD = device_context.allocate_tensor( + options, + "Reference_SFD", + operation_desc.SFD.element, + operation_desc.SFD.layout, + {sfd_m, sfd_n}, + {sfd_n}, + problem_.batch_count * gemm_workspace_.problem_count, + 0 // device_index + ); + + gemm_workspace_.Norm_constant = device_context.allocate_and_initialize_tensor( + options, + "Norm_constant", + operation_desc.element_epilogue, + operation_desc.A.layout, + {1, 1}, + {1}, + 1, + seed_shift++, + 0 // device_index + ); + + } + + if (options.execution_mode != ExecutionMode::kDryRun) { + + // ScaleFactor tensor results may have some holes and will not be touched by the kernel. + // If we randomly fill the two tensors, these holes may encounter refcheck errors. 
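To make the scale-factor extents computed above concrete, the small program below evaluates the same round_up/ceil_div arithmetic for one assumed shape; m, n, k and SFVecSize here are example values only. This padding is also why the SFD tensors are zero-filled just below: padded positions are never written by the kernel, so the computed and reference tensors must start out equal there.

#include <cstdio>

// Worked example of the scale-factor tensor extents (assumed values).
static int ceil_div(int a, int b) { return (a + b - 1) / b; }
static int round_up(int a, int b) { return ceil_div(a, b) * b; }

int main() {
  int m = 1000, n = 1024, k = 4096;
  int SFVecSize = 16;  // illustrative: one scale factor per 16 K-elements

  int sfa_m     = round_up(m, 128);                     // 1000 -> 1024
  int sfb_n     = round_up(n, 128);                     // 1024 -> 1024
  int sfa_sfb_k = round_up(ceil_div(k, SFVecSize), 4);  // ceil(4096/16) = 256 -> 256

  std::printf("SFA extent: %d x %d, SFB extent: %d x %d\n", sfa_m, sfa_sfb_k, sfb_n, sfa_sfb_k);
  return 0;
}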
+ if (gemm_workspace_.Computed_SFD->type() != library::NumericTypeID::kVoid) { + if (options.initialization.provider == library::Provider::kReferenceHost) { + gemm_workspace_.Reference_SFD->fill_host(0); + gemm_workspace_.Computed_SFD->fill_host(0); + } + else { + gemm_workspace_.Reference_SFD->fill_device(0); + gemm_workspace_.Computed_SFD->fill_device(0); + } + } + + + // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels + gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + gemm_workspace_.arguments.split_k_slices = problem_.split_k_slices; + gemm_workspace_.arguments.batch_count = problem_.batch_count; + gemm_workspace_.arguments.lda = problem_.lda; + gemm_workspace_.arguments.ldb = problem_.ldb; + gemm_workspace_.arguments.ldc = problem_.ldc; + gemm_workspace_.arguments.ldd = problem_.ldc; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + gemm_workspace_.arguments.use_pdl = problem_.use_pdl; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount; + } + + // + // Initialize the CUTLASS operation + // + Status status = Status::kSuccess; + + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + + if (options.execution_mode != ExecutionMode::kDryRun) { + uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); + gemm_workspace_.host_workspace.resize(workspace_size, 0); + + workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, + &gemm_workspace_.arguments); + gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + + status = underlying_operation->initialize( + &gemm_workspace_.configuration, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + if (status != Status::kSuccess) { + return status; + } + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration); + gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0); + + status = reduction_op_->initialize( + &gemm_workspace_.reduction_configuration, + gemm_workspace_.reduction_host_workspace.data(), + nullptr); + + if (status != Status::kSuccess) { + return status; + } + } + } + + // + // If CUTLASS is enabled, generate a result for it + // + results_.push_back(model_result_); + results_.back().provider = library::Provider::kCUTLASS; + results_.back().op_kind = library::OperationKind::kGemm; + results_.back().disposition = Disposition::kNotRun; + + for (auto provider : verification_providers_) { + results_.back().verification_map[provider] = Disposition::kNotRun; + } + } + return status; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Verifies CUTLASS against 
references +bool BlockScaledGemmOperationProfiler::verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + return true; + } + + if (options.execution_mode == ExecutionMode::kDryRun) { + return true; + } + + // Initialize structure containing GEMM arguments + gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.SFA = gemm_workspace_.SFA->data(); + gemm_workspace_.arguments.SFB = gemm_workspace_.SFB->data(); + gemm_workspace_.arguments.C = gemm_workspace_.C->data(); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); + gemm_workspace_.arguments.SFD = gemm_workspace_.Computed_SFD->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.norm_constant = gemm_workspace_.Norm_constant->data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + + gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); + gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); + gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); + gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } + + // + // Run the CUTLASS operation + // + + // initialize gemm underlying operation to handle parallel reduction + library::Operation const * underlying_operation = operation; + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { + results_.back().disposition = Disposition::kFailed; + return false; + } + } + + results_.back().status = underlying_operation->run( + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data(), + nullptr); + + if (results_.back().status != Status::kSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + results_.back().status = reduction_op_->run( + &gemm_workspace_.reduction_arguments, + gemm_workspace_.reduction_host_workspace.data(), + nullptr, + nullptr); + + if (results_.back().status != Status::kSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + 
results_.back().disposition = Disposition::kFailed; + return false; + } + + // CUTLASS op ran, but has not yet been verified against any verification provider + results_.back().disposition = Disposition::kNotVerified; + + // + // Run verification providers + // + + if (options.verification.enabled) { + +#if CUTLASS_ENABLE_CUBLAS + if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { + // set verification map for cublas to not supported + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; + } +#endif // #if CUTLASS_ENABLE_CUBLAS + + + cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.arguments.runtime_input_datatype_a; + cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.arguments.runtime_input_datatype_b; + + bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic; + bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic; + + assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatypes must both be dynamic or both be static."); + + library::OperationDescription const &desc = operation->description(); + auto &gemm_desc = static_cast<library::BlockScaledGemmDescription const &>(desc); + + cutlass::library::NumericTypeID element_A = gemm_desc.A.element; + cutlass::library::NumericTypeID element_B = gemm_desc.B.element; + + if (is_runtime_datatype_a) { + element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a); + } + + if (is_runtime_datatype_b) { + element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b); + } + + + bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B); + + // Update disposition to worst case verification outcome among all + // verification providers which are supported + bool is_any_verification_run_passed = false; + for (auto &m : results_.back().verification_map) { + if (m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { + results_.back().disposition = m.second; + return true; + } + if (!is_any_verification_run_passed && m.second == Disposition::kPassed) { + is_any_verification_run_passed = true; + } + } + + if (is_any_verification_run_passed) { + results_.back().disposition = Disposition::kPassed; + } + } + + // if verification.required is set, then return success iff at least one ref-check was run + if (options.verification.required) { + bool did_any_verification_run = false; + for (auto provider : options.verification.providers) { + did_any_verification_run |= (Disposition::kNotRun != results_.back().verification_map[provider]); + } + + if (not did_any_verification_run) { + results_.back().status = Status::kErrorNotSupported; + return false; + } + } + + // Return true means continue profiling + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Verifies CUTLASS against cuBLAS (not supported for block-scaled GEMM) +bool BlockScaledGemmOperationProfiler::verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + +#if CUTLASS_ENABLE_CUBLAS + std::cerr << "cuBLAS is not supported" << std::endl; +#endif + + // Return true means continue profiling + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Verifies CUTLASS against 
host and device references +bool BlockScaledGemmOperationProfiler::verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B) { + + /// Verifies CUTLASS against host reference + + // + // Find host reference operation using the block-scaled GEMM functional description key + // + library::OperationDescription const &desc = operation->description(); + + auto &gemm_desc = static_cast<library::BlockScaledGemmDescription const &>(desc); + + library::BlockScaledGemmFunctionalKey blockScaledGemm_key( + library::Provider::kReferenceHost, + gemm_desc.gemm_kind, + gemm_desc.kind, + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + element_A, + gemm_desc.A.layout, + gemm_desc.SFA.element, + element_B, + gemm_desc.B.layout, + gemm_desc.SFB.element, + gemm_desc.C.element, + gemm_desc.C.layout, + gemm_desc.D.element, + gemm_desc.D.layout, + gemm_desc.SFD.element, + gemm_desc.SFD.layout, + gemm_desc.SFVecSize + , gemm_desc.EpilogueSFVecSize + ); + + auto operators_it = library::Singleton::get().operation_table.block_scaled_gemm_operations.find(blockScaledGemm_key); + + if (operators_it == library::Singleton::get().operation_table.block_scaled_gemm_operations.end()) { + return true; + } + + if (operators_it->second.empty()) { + return true; + } + + // Do not use preference to filter the reference kernel. + auto cc_it = operators_it->second.begin(); + + if (cc_it == operators_it->second.end()) { + std::cout << "Could not find any reference kernel" << std::endl; + results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; + return true; + } + + // The host reference has only one instance in BlockScaledOperationVectorMap + library::Operation const *reference_op = cc_it->second[0]; + + // To support the host-side reference, conditionally allocate and + // copy tensors to host memory. 
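Before the operands are staged to the host below, it may help to state what the host reference is expected to compute. Ignoring the padded scale-factor layouts, the optional SFD output, and the normalization constant, a block-scaled GEMM applies one scale factor per SFVecSize-wide block of K to each row of A and each column of B. The naive sketch below only illustrates that contract; it is not the CUTLASS reference implementation, and the layouts and template parameters are assumptions.

// Naive block-scaled GEMM: row-major A (m x k), column-major B (k x n), row-major C/D (m x n).
// SFA is indexed [row][k-block], SFB is indexed [column][k-block]; extent padding is ignored.
template <typename TA, typename TB, typename TSF, typename TC, typename TAcc>
void block_scaled_gemm_ref(int m, int n, int k, int sf_vec_size, TAcc alpha, TAcc beta,
                           TA const *A, TSF const *SFA, TB const *B, TSF const *SFB,
                           TC const *C, TC *D) {
  int k_blocks = (k + sf_vec_size - 1) / sf_vec_size;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      TAcc acc = TAcc(0);
      for (int kk = 0; kk < k; ++kk) {
        TAcc a = TAcc(SFA[i * k_blocks + kk / sf_vec_size]) * TAcc(A[i * k + kk]);
        TAcc b = TAcc(SFB[j * k_blocks + kk / sf_vec_size]) * TAcc(B[j * k + kk]);
        acc += a * b;
      }
      D[i * n + j] = TC(alpha * acc + beta * TAcc(C[i * n + j]));
    }
  }
}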
+ std::vector host_data_A; + std::vector host_data_SFA; + std::vector host_data_B; + std::vector host_data_SFB; + std::vector host_data_C; + std::vector host_data_D; + std::vector host_data_SFD; + std::vector host_data_Norm_constant; + + // + // Copy input tensors A, B, and C from device to host buffers + // + + host_data_A.resize(gemm_workspace_.A->bytes()); + void * ptr_A = host_data_A.data(); + gemm_workspace_.A->copy_to_host(ptr_A); + + host_data_SFA.resize(gemm_workspace_.SFA->bytes()); + void * ptr_SFA = host_data_SFA.data(); + gemm_workspace_.SFA->copy_to_host(ptr_SFA); + + host_data_B.resize(gemm_workspace_.B->bytes()); + void * ptr_B = host_data_B.data(); + gemm_workspace_.B->copy_to_host(ptr_B); + + host_data_SFB.resize(gemm_workspace_.SFB->bytes()); + void * ptr_SFB = host_data_SFB.data(); + gemm_workspace_.SFB->copy_to_host(ptr_SFB); + + host_data_C.resize(gemm_workspace_.C->bytes()); + void * ptr_C = host_data_C.data(); + gemm_workspace_.C->copy_to_host(ptr_C); + + host_data_Norm_constant.resize(gemm_workspace_.Norm_constant->bytes()); + void * ptr_Norm_constant = host_data_Norm_constant.data(); + gemm_workspace_.Norm_constant->copy_to_host(ptr_Norm_constant); + + host_data_D.resize(gemm_workspace_.Reference->bytes()); + void * ptr_D = host_data_D.data(); + + host_data_SFD.resize(gemm_workspace_.Reference_SFD->bytes()); + void * ptr_SFD = host_data_SFD.data(); + + /// Set reference kernel Arguments + + library::BlockScaledGemmArguments arguments { + {int(problem_.m), int(problem_.n), int(problem_.k)}, + {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}, + {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}, + gemm_workspace_.configuration.batch_count, + ptr_A, + ptr_B, + ptr_SFA, + ptr_SFB, + ptr_C, + ptr_D, + ptr_SFD, + problem_.alpha.data(), + problem_.beta.data(), + library::ScalarPointerMode::kHost, + int(gemm_workspace_.configuration.lda), + int(gemm_workspace_.configuration.ldb), + int(gemm_workspace_.configuration.ldc), + int(gemm_workspace_.configuration.ldd), + gemm_workspace_.A->batch_stride(), + gemm_workspace_.B->batch_stride(), + gemm_workspace_.C->batch_stride(), + gemm_workspace_.Reference->batch_stride() + , ptr_Norm_constant + }; + + // Query host work space size + uint64_t host_workspace_size_needed = reference_op->get_host_workspace_size(&gemm_workspace_.configuration); + + std::vector host_workspace(host_workspace_size_needed); + + // Query device workspace size + uint64_t device_workspace_size_needed = reference_op->get_device_workspace_size(&gemm_workspace_.configuration); + // Initialize host and device workspaces + Status status = reference_op->initialize( + &gemm_workspace_.configuration, + host_workspace.data() + ); + + if (status != cutlass::Status::kSuccess) { + results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; + return true; + } + + // Run the operator + status = reference_op->run(&arguments, host_workspace.data()); + + results_.back().status = status; + + gemm_workspace_.Reference->copy_from_host(ptr_D); + gemm_workspace_.Reference_SFD->copy_from_host(ptr_SFD); + + // + // Verify results + // + auto resultD = compare_tensors( + options, + *gemm_workspace_.Computed, + *gemm_workspace_.Reference, + gemm_workspace_.Computed->batch_stride() + ); + + auto resultSFD = Disposition::kPassed; + if (gemm_desc.SFD.element != library::NumericTypeID::kVoid) { + resultSFD = compare_tensors( + options, + *gemm_workspace_.Computed_SFD, + 
*gemm_workspace_.Reference_SFD, + gemm_workspace_.Computed_SFD->batch_stride() + ); + } + + results_.back().verification_map[library::Provider::kReferenceHost] = resultD; + + if (resultSFD != Disposition::kPassed) { + results_.back().verification_map[library::Provider::kReferenceHost] = resultSFD; + } + + + // Save workspace if incorrect + if (options.verification.save_workspace == SaveWorkspace::kIncorrect && + results_.back().verification_map[library::Provider::kReferenceHost] == Disposition::kIncorrect) { + save_workspace( + device_context, + options, + gemm_desc, + library::Provider::kCUTLASS, + library::Provider::kReferenceHost); + } + + // Return true means continue profiling + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Measures performance results +bool BlockScaledGemmOperationProfiler::profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + + // Initialize structure containing GEMM arguments + gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.SFA = gemm_workspace_.SFA->data(); + gemm_workspace_.arguments.SFB = gemm_workspace_.SFB->data(); + gemm_workspace_.arguments.C = gemm_workspace_.C->data(); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + + gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); + gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); + gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); + gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } + + results_.back().status = profile_cutlass_( + results_.back(), + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data() + ); + } + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Method to profile a CUTLASS Operation +Status BlockScaledGemmOperationProfiler::profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace) { + + // initialize gemm underlying operation to handle parallel reduction + 
library::Operation const * underlying_operation = operation; + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { + return Status::kErrorNotSupported; + } + } + + auto func = [&](cudaStream_t, int iteration) { + // Iterate over copies of the problem in memory + int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; + + gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); + gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); + gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + + gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); + gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); + gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); + } + + Status status = underlying_operation->run( + arguments, + host_workspace, + device_workspace, + nullptr); + + if (status != Status::kSuccess) { + return status; + } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + status = reduction_op_->run( + &gemm_workspace_.reduction_arguments, + gemm_workspace_.reduction_host_workspace.data(), + nullptr, + nullptr); + + if (status != Status::kSuccess) { + return status; + } + } + + return status; + }; + + return profile_kernel_(result, options, func); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index b407004674..efffefb741 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -38,6 +38,7 @@ // Profiler includes #include "cutlass/profiler/cutlass_profiler.h" #include "cutlass/profiler/gemm_operation_profiler.h" +#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h" #include "cutlass/profiler/rank_k_operation_profiler.h" #include "cutlass/profiler/rank_2k_operation_profiler.h" #include "cutlass/profiler/trmm_operation_profiler.h" @@ -60,6 +61,8 @@ CutlassProfiler::CutlassProfiler( operation_profilers_.emplace_back(new GemmOperationProfiler(options)); + operation_profilers_.emplace_back(new BlockScaledGemmOperationProfiler(options)); + operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options)); operation_profilers_.emplace_back(new Conv2dOperationProfiler(options)); diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index f06b9607cb..741a5f04af 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -616,6 +616,48 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + + case library::NumericTypeID::kFUE4M3: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFUE8M0: + 
cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE2M3: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE3M2: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE2M1: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kF64: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -771,6 +813,50 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + + case library::NumericTypeID::kFUE4M3: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + + + case library::NumericTypeID::kFE2M3: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE3M2: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE2M1: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFUE8M0: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kF16: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -990,6 +1076,50 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) { static_cast(dist.sequential.start) ); break; + + case library::NumericTypeID::kFUE4M3: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + + + case library::NumericTypeID::kFE2M3: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFE3M2: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFE2M1: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFUE8M0: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kF16: cutlass::reference::device::BlockFillSequential( reinterpret_cast(pointer_), @@ -1220,6 +1350,50 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) { static_cast(dist.sequential.start) ); break; + + case library::NumericTypeID::kFUE4M3: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + + + case library::NumericTypeID::kFE2M3: + 
cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFE3M2: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFE2M1: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFUE8M0: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kF16: cutlass::reference::host::BlockFillSequential( reinterpret_cast(host_data.data()), @@ -1516,6 +1690,34 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); + + case library::NumericTypeID::kFUE4M3: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + case library::NumericTypeID::kFUE8M0: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + case library::NumericTypeID::kFE2M3: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + + case library::NumericTypeID::kFE3M2: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + case library::NumericTypeID::kFE2M1: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + case library::NumericTypeID::kF16: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -1684,6 +1886,46 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity, static_cast(epsilon), static_cast(nonzero_floor)); + + case library::NumericTypeID::kFUE4M3: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + case library::NumericTypeID::kFUE8M0: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kFE2M3: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kFE3M2: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kFE2M1: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + case library::NumericTypeID::kF16: return reference::device::BlockCompareRelativelyEqual( reinterpret_cast(ptr_A), @@ -2026,6 +2268,27 @@ void DeviceAllocation::write_tensor_csv( case library::NumericTypeID::kFE5M2: write_tensor_csv_static_type(out, *this); break; + + case library::NumericTypeID::kFUE4M3: + 
write_tensor_csv_static_type(out, *this); + break; + + + case library::NumericTypeID::kFE2M3: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kFE3M2: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kFE2M1: + write_tensor_csv_static_type(out, *this); + break; + case library::NumericTypeID::kFUE8M0: + write_tensor_csv_static_type(out, *this); + break; + case library::NumericTypeID::kF16: write_tensor_csv_static_type(out, *this); break; @@ -2193,6 +2456,27 @@ void DeviceAllocation::fill_device(double val = 0.0) { case library::NumericTypeID::kFE5M2: tensor_fill(*this, static_cast(val)); break; + + case library::NumericTypeID::kFUE4M3: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kFUE8M0: + tensor_fill(*this, static_cast(val)); + break; + case library::NumericTypeID::kFE2M3: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kFE3M2: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kFE2M1: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kF16: tensor_fill(*this, static_cast(val)); break; @@ -2288,6 +2572,47 @@ void DeviceAllocation::fill_host(double val = 0.0) { std::vector host_data(bytes()); switch (this->type()) { + + case library::NumericTypeID::kFUE4M3: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kFUE8M0: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + case library::NumericTypeID::kFE2M3: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kFE3M2: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kFE2M1: + cutlass::reference::host::BlockFill( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(val) + ); + break; + + case library::NumericTypeID::kFE4M3: cutlass::reference::host::BlockFill( reinterpret_cast(host_data.data()), diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index b90b2ee19f..0ac618fda7 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -104,6 +104,25 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor( case library::NumericTypeID::kFE5M2: data_distribution.set_uniform(-1, 1, 0); break; + + case library::NumericTypeID::kFE2M3: + data_distribution.set_uniform(-2, 2, 0); + break; + case library::NumericTypeID::kFE3M2: + data_distribution.set_uniform(-2, 2, 0); + break; + case library::NumericTypeID::kFE2M1: + data_distribution.set_uniform(-2, 2, 0); + break; + case library::NumericTypeID::kFUE8M0: + data_distribution.set_uniform(1, 4, 0); + break; + + case library::NumericTypeID::kFUE4M3: + data_distribution.set_uniform(1, 4, 0); + break; + + case library::NumericTypeID::kF16: data_distribution.set_uniform(-3, 3, 0); break; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 80d346a018..fc2346d2b8 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -76,6 +76,8 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): 
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, {ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"}, + {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, + {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, {ArgumentTypeID::kInteger, {"use_pdl", "use-pdl"}, "Use PDL (true, false)"}, {ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"}, }, @@ -172,6 +174,38 @@ Status GemmOperationProfiler::GemmProblem::parse( this->k = 1024; } + + if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { + // default value + this->cluster_m = 1; + } + + if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { + // default value + this->cluster_n = 1; + } + + if (!arg_as_int(this->cluster_k, "cluster_k", problem_space, problem)) { + // default value + this->cluster_k = 1; + } + + if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { + // default value + this->cluster_m_fallback = 0; + } + + if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { + // default value + this->cluster_n_fallback = 0; + } + + if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { + // default value + this->cluster_k_fallback = 0; + } + + if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) { // default value this->use_pdl = false; @@ -192,6 +226,18 @@ Status GemmOperationProfiler::GemmProblem::parse( this->split_k_slices = 1; } + + if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) { + // default value + this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic; + } + + if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) { + // default value + this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic; + } + + if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { // default value this->batch_count = 1; @@ -338,6 +384,15 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "n", problem_space, n); set_argument(result, "k", problem_space, k); + + set_argument(result, "cluster_m", problem_space, cluster_m); + set_argument(result, "cluster_n", problem_space, cluster_n); + set_argument(result, "cluster_k", problem_space, cluster_k); + set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback); + set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback); + set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback); + + set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); set_argument(result, "split_k_slices", problem_space, split_k_slices); set_argument(result, "batch_count", problem_space, batch_count); @@ -345,6 +400,11 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "swizzle_size", problem_space, swizzle_size); set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl)); + + set_argument(result, 
"runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a)); + set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b)); + + set_argument(result, "alpha", problem_space, library::lexical_cast(alpha, operation_desc.element_epilogue)); @@ -388,6 +448,14 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_[i].configuration.problem_size.m() = int(problem_.m); gemm_workspace_[i].configuration.problem_size.n() = int(problem_.n); gemm_workspace_[i].configuration.problem_size.k() = int(problem_.k); + + gemm_workspace_[i].configuration.cluster_shape.m() = int(problem_.cluster_m); + gemm_workspace_[i].configuration.cluster_shape.n() = int(problem_.cluster_n); + gemm_workspace_[i].configuration.cluster_shape.k() = int(problem_.cluster_k); + gemm_workspace_[i].configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback); + gemm_workspace_[i].configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback); + gemm_workspace_[i].configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback); + gemm_workspace_[i].configuration.lda = problem_.lda; gemm_workspace_[i].configuration.ldb = problem_.ldb; gemm_workspace_[i].configuration.ldc = problem_.ldc; @@ -423,6 +491,15 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; gemm_workspace_[i].arguments.swizzle_size = problem_.swizzle_size; gemm_workspace_[i].arguments.raster_order = problem_.raster_order; + gemm_workspace_[i].arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_[i].arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + gemm_workspace_[i].arguments.split_k_slices = problem_.split_k_slices; + + + gemm_workspace_[i].arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a; + gemm_workspace_[i].arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b; + + initialize_result_(this->model_result_, options, operation_desc, problem_space); if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) { return can_implement; @@ -621,6 +698,9 @@ Status GemmOperationProfiler::initialize_workspace( if (options.execution_mode != ExecutionMode::kDryRun) { // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; + gemm_workspace_[i].arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_[i].arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + gemm_workspace_[i].arguments.split_k_slices = problem_.split_k_slices; gemm_workspace_[i].arguments.batch_count = problem_.batch_count; gemm_workspace_[i].arguments.lda = problem_.lda; gemm_workspace_[i].arguments.ldb = problem_.ldb; @@ -857,12 +937,32 @@ bool GemmOperationProfiler::verify_cutlass( } #endif // #if CUTLASS_ENABLE_CUBLAS + + cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.front().arguments.runtime_input_datatype_a; + cutlass::library::RuntimeDatatype runtime_datatype_b = 
gemm_workspace_.front().arguments.runtime_input_datatype_b; + + bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic; + bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic; + + assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatypes for A and B must both be dynamic or both be static."); + + + library::GemmDescription const &gemm_desc = static_cast<library::GemmDescription const &>(operation->description()); cutlass::library::NumericTypeID element_A = gemm_desc.A.element; cutlass::library::NumericTypeID element_B = gemm_desc.B.element; + + if (is_runtime_datatype_a) { + element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a); + } + + if (is_runtime_datatype_b) { + element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b); + } + + bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B); // Update disposition to worst case verification outcome among all @@ -1087,6 +1187,14 @@ bool GemmOperationProfiler::verify_with_reference_( gemm_workspace_[i].configuration.problem_size.m(), gemm_workspace_[i].configuration.problem_size.n(), gemm_workspace_[i].configuration.problem_size.k(), + + gemm_workspace_[i].configuration.cluster_shape.m(), + gemm_workspace_[i].configuration.cluster_shape.n(), + gemm_workspace_[i].configuration.cluster_shape.k(), + gemm_workspace_[i].configuration.cluster_shape_fallback.m(), + gemm_workspace_[i].configuration.cluster_shape_fallback.n(), + gemm_workspace_[i].configuration.cluster_shape_fallback.k(), + gemm_desc.tile_description.math_instruction.element_accumulator, gemm_desc.element_epilogue, diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index d11009ce82..5a518d7fbc 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -91,6 +91,11 @@ OperationProfiler::OperationProfiler( {ArgumentTypeID::kInteger, {"cluster_m", "cluster-shape::m"}, "Cluster shape in the M dimension"}, {ArgumentTypeID::kInteger, {"cluster_n", "cluster-shape::n"}, "Cluster shape in the N dimension"}, {ArgumentTypeID::kInteger, {"cluster_k", "cluster-shape::k"}, "Cluster shape in the K dimension"}, + + {ArgumentTypeID::kInteger, {"cluster_m_fallback", "cluster-shape-fallback::m"}, "Fallback cluster shape in the M dimension"}, + {ArgumentTypeID::kInteger, {"cluster_n_fallback", "cluster-shape-fallback::n"}, "Fallback cluster shape in the N dimension"}, + {ArgumentTypeID::kInteger, {"cluster_k_fallback", "cluster-shape-fallback::k"}, "Fallback cluster shape in the K dimension"}, + {ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"}, {ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"}, {ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"}, @@ -174,6 +179,11 @@ bool OperationProfiler::satisfies( return false; } } + + bool dynamic_cluster = int64_t(op_desc.tile_description.cluster_shape.m()) == 0 || + int64_t(op_desc.tile_description.cluster_shape.n()) == 0 || + int64_t(op_desc.tile_description.cluster_shape.k()) == 0; + int64_t int_value; if (arg_as_int(int_value, "inst_m", problem_space, problem)) { @@ -212,6 +222,7 @@ bool OperationProfiler::satisfies( } } + + if (!dynamic_cluster) { if (arg_as_int(int_value, "cluster_m", problem_space, problem)) { if
(int64_t(op_desc.tile_description.cluster_shape.m()) != int_value) { return false; @@ -230,6 +241,7 @@ bool OperationProfiler::satisfies( } } + } if (arg_as_int(int_value, "stages", problem_space, problem)) { if (int64_t(op_desc.tile_description.threadblock_stages) != int_value) { return false; @@ -296,6 +308,11 @@ std::ostream& operator<<(std::ostream& out, library::OperationKind provider) { if (provider == library::OperationKind::kGemm) { out << "kGemm"; } + + else if (provider == library::OperationKind::kBlockScaledGemm) { + out << "kBlockScaledGemm"; + } + else if (provider == library::OperationKind::kRankK) { out << "kRankK"; } diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 4dd066fe3f..0adc2340b4 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -33,6 +33,7 @@ */ #include +#include <fstream> #include #include "cutlass/cutlass.h" @@ -810,16 +811,27 @@ Options::Options(cutlass::CommandLine const &cmdline): } else if (cmdline.check_cmd_line_flag("kernels")) { cmdline.get_cmd_line_arguments("kernels", operation_names); - profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match"); - profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled"); + } + + if (cmdline.check_cmd_line_flag("kernels-file")) { + std::string filename; + cmdline.get_cmd_line_argument("kernels-file", filename, {}); + std::ifstream input(filename); + if (!input.good()) { + throw std::runtime_error("failed to open: " + filename); + } + for (std::string line; getline(input, line);) { + operation_names.push_back(line); + } } if (cmdline.check_cmd_line_flag("ignore-kernels")) { cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names); - profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match"); - profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled"); } + profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match"); + profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled"); + // Prevent launches on the device for anything other than CUTLASS operation // Allow verification only on host if (execution_mode == ExecutionMode::kTrace) { @@ -856,6 +868,11 @@ void Options::print_usage(std::ostream &out) const { << " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line << " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n" + << " --kernels-file= " << " Same behavior as --kernels, but kernel names are specified in a file" << end_of_line + << " with one kernel on each line. Set of profiled kernels is the union of kernels specified" << end_of_line + << " here and those specified in `kernels`.\n\n" + << " --ignore-kernels= " << " Excludes kernels whose names match anything in this list.\n\n" ; diff --git a/tools/profiler/src/problem_space.cpp b/tools/profiler/src/problem_space.cpp index ced00009bd..0d8ade05c2 100644 --- a/tools/profiler/src/problem_space.cpp +++ b/tools/profiler/src/problem_space.cpp @@ -879,6 +879,32 @@ bool arg_as_NumericTypeID( ///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexically casts an argument to a RuntimeDatatype if it is defined. Returns true if not null.
+bool arg_as_RuntimeDatatype( + library::RuntimeDatatype &runtime_datatype, + KernelArgument::Value const *value_ptr) { + + if (value_ptr->not_null) { + if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) { + + runtime_datatype = library::from_string<library::RuntimeDatatype>( + static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element); + if (runtime_datatype == library::RuntimeDatatype::kInvalid) { + throw std::runtime_error( + "arg_as_RuntimeDatatype() - illegal cast."); + } + } + else { + throw std::runtime_error( + "arg_as_RuntimeDatatype() - illegal cast."); + } + return true; + } + return false; +} + + /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_RasterOrder( library::RasterOrder &raster_order, @@ -945,6 +971,21 @@ bool arg_as_LayoutTypeID( return false; } + +/// Lexically casts an argument to a RuntimeDatatype if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype( + library::RuntimeDatatype &runtime_datatype, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + size_t idx = problem_space.argument_index(name); + KernelArgument::Value const *value_ptr = problem.at(idx).get(); + + return arg_as_RuntimeDatatype(runtime_datatype, value_ptr); +} + + /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_LayoutTypeID( library::LayoutTypeID &layout_type, diff --git a/tools/util/include/cutlass/util/reference/device/convolution.h b/tools/util/include/cutlass/util/reference/device/convolution.h index 552a7a2e19..7c6f803c47 100644 --- a/tools/util/include/cutlass/util/reference/device/convolution.h +++ b/tools/util/include/cutlass/util/reference/device/convolution.h @@ -922,7 +922,7 @@ __global__ void Conv3dWgrad( filter_s = problem_size.S - 1 - filter_s; } - int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d; + int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; diff --git a/tools/util/include/cutlass/util/reference/host/gemm.h b/tools/util/include/cutlass/util/reference/host/gemm.h index dc5a2be6dd..2afee7b36d 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm.h +++ b/tools/util/include/cutlass/util/reference/host/gemm.h @@ -352,7 +352,7 @@ struct Gemm>( + ScalarType, ComputeType, xor_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); } @@ -367,7 +367,7 @@ struct Gemm>( + ScalarType, ComputeType, xor_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); } }; @@ -389,7 +389,7 @@ struct Gemm>( + ScalarType, ComputeType, and_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); } @@ -404,7 +404,7 @@ struct Gemm>( + ScalarType, ComputeType, and_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); } }; diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 98ad45e937..534f546209 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -42,6 +42,7 @@ #include "cutlass/relatively_equal.h" #include "cute/tensor.hpp" +#include "cute/pointer.hpp"
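// Illustrative sketch (not from the CUTLASS sources): the Conv3dWgrad hunk above is a one-token
// fix, subtracting pad_d instead of pad_w when forming the output depth index. The member names
// below come from that formula; the numeric values are made up purely to show how the two
// expressions diverge whenever the depth and width paddings differ.

#include <cstdio>

int main() {
  int Z = 2, stride_d = 2, filter_t = 1, dilation_d = 1;
  int pad_d = 2, pad_w = 0;  // depth and width padding are not equal

  int d_old = Z * stride_d - pad_w + filter_t * dilation_d;  // previous code (wrong): 5
  int d_new = Z * stride_d - pad_d + filter_t * dilation_d;  // fixed code:            3

  std::printf("d with pad_w = %d, d with pad_d = %d\n", d_old, d_new);
  return 0;
}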
///////////////////////////////////////////////////////////////////////////////////////////////// @@ -59,10 +60,20 @@ struct ElementTraits struct GettMainloopParams { using ElementAccumulator = ElementAccumulator_; @@ -79,23 +90,105 @@ struct GettMainloopParams { ComplexTransform transform_A = ComplexTransform::kNone; ComplexTransform transform_B = ComplexTransform::kNone; + + using TensorSfA = TensorSfA_; + using TensorSfB = TensorSfB_; + using EngineSfA = typename TensorSfA::engine_type; + using LayoutSfA = typename TensorSfA::layout_type; + using EngineSfB = typename TensorSfB::engine_type; + using LayoutSfB = typename TensorSfB::layout_type; + TensorSfA_ SfA{}; + TensorSfB_ SfB{}; + + + GettMainloopParams() {} + + GettMainloopParams(TensorA tensor_A, TensorB tensor_B) + : A(tensor_A), B(tensor_B) {} + + + GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : A(tensor_A), SfA(tensor_SfA), + B(tensor_B), SfB(tensor_SfB) {} + + +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorSfA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorSfB_ // (N, K, L) +> +struct GettBlockScalingMainloopParams : public GettMainloopParams { + using Base = GettMainloopParams; + using ElementAccumulator = typename Base::ElementAccumulator; + using TensorA = typename Base::TensorA; + using TensorB = typename Base::TensorB; + using EngineA = typename Base::EngineA; + using LayoutA = typename Base::LayoutA; + using EngineB = typename Base::EngineB; + using LayoutB = typename Base::LayoutB; + ComplexTransform transform_A = Base::transform_A; + ComplexTransform transform_B = Base::transform_B; + + using TensorSfA = typename Base::TensorSfA; + using TensorSfB = typename Base::TensorSfB; + using EngineSfA = typename Base::EngineSfA; + using LayoutSfA = typename Base::LayoutSfA; + using EngineSfB = typename Base::EngineSfB; + using LayoutSfB = typename Base::LayoutSfB; + + GettBlockScalingMainloopParams() {} + + GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {} + + }; + ///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class SfStrategy { + None = 0, + SfDGen = 1 +}; + + +/////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters +// +/////////////////////////////////////////////////////////// + template< class ElementScalar_, class ElementScalingFactor_, class ElementAccumulator_, class ElementCompute_, - class TensorC_, // (M, N, L) - class TensorD_, // (M, N, L) - class VectorBias_ = TensorD_, // (M, 1) - class TensorAux_ = TensorD_, // (M, N, L) - class VectorAlpha_ = TensorD_, // (M, 1) - class VectorBeta_ = VectorAlpha_, // (M, 1) + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class TensorAux_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, N, L) + class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) 
class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class TensorSFD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, class BiasBinaryOp_ = cutlass::plus, bool PerColumnBias_ = false + , + SfStrategy SfGenStrategy_ = SfStrategy::None > struct GettEpilogueParams { using ElementScalar = ElementScalar_; @@ -108,6 +201,8 @@ struct GettEpilogueParams { using VectorBias = VectorBias_; using VectorAlpha = VectorAlpha_; using VectorBeta = VectorBeta_; + using TensorSFD = TensorSFD_; + using SFD_VectorSize = SFD_VectorSize_; using ActivationFunctor = ActivationFunctor_; using BiasBinaryOp = BiasBinaryOp_; @@ -115,7 +210,11 @@ struct GettEpilogueParams { using LayoutC = typename TensorC::layout_type; using EngineD = typename TensorD::engine_type; using LayoutD = typename TensorD::layout_type; + using EngineSfD = typename TensorSFD::engine_type; + using LayoutSfD = typename TensorSFD::layout_type; static constexpr bool PerColumnBias = PerColumnBias_; + static constexpr SfStrategy SfGenStrategy = SfGenStrategy_; + ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); @@ -125,7 +224,8 @@ struct GettEpilogueParams { TensorAux Aux{}; VectorAlpha Valpha{}; VectorBeta Vbeta{}; - ElementCompute st = ElementCompute(1); + TensorSFD SfD{}; + ElementCompute st = ElementCompute(1); ElementAccumulator* abs_max_D = nullptr; ElementAccumulator* abs_max_Aux = nullptr; @@ -137,8 +237,250 @@ struct GettEpilogueParams { ElementScalingFactor scale_aux = ElementScalingFactor(1); bool beta_per_channel_scaling = false; + GettEpilogueParams() {} + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {} + + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {} + + + GettEpilogueParams( + ElementScalar alpha, ElementScalar beta, + TensorC tensor_C, TensorD tensor_D, + VectorBias bias, TensorAux tensor_aux, + VectorAlpha vector_alpha, VectorBeta vector_beta) + : alpha(alpha), beta(beta), + C(tensor_C), D(tensor_D), + Bias(bias), Aux(tensor_aux), + Valpha(vector_alpha), Vbeta(vector_beta) {} }; + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, + class TensorD_, + class TensorSfD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettBlockScalingEpilogueParams : public GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + 
cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // class BiasBinaryOp_ = + false, //PerColumnBias_ + SfGenStrategy_ // SfGenStrategy + > { + using Base = GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // BiasBinaryOp + false, // PerColumnBias + SfGenStrategy_ // SfGenStrategy + >; + using ElementScalar = typename Base::ElementScalar; + using ElementScalingFactor = typename Base::ElementScalingFactor; + using ElementAccumulator = typename Base::ElementAccumulator; + using ElementCompute = typename Base::ElementCompute; + using TensorC = typename Base::TensorC; + using TensorD = typename Base::TensorD; + using TensorAux = typename Base::TensorAux; + using VectorBias = typename Base::VectorBias; + using VectorAlpha = typename Base::VectorAlpha; + using VectorBeta = typename Base::VectorBeta; + using TensorSFD = typename Base::TensorSFD; + using SFD_VectorSize = typename Base::SFD_VectorSize; + using ActivationFunctor = typename Base::ActivationFunctor; + using BiasBinaryOp = typename Base::BiasBinaryOp; + + using EngineC = typename Base::EngineC; + using LayoutC = typename Base::LayoutC; + using EngineD = typename Base::EngineD; + using LayoutD = typename Base::LayoutD; + using EngineSfD = typename Base::EngineSfD; + using LayoutSfD = typename Base::LayoutSfD; + static constexpr bool PerColumnBias = Base::PerColumnBias; + static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy; + + GettBlockScalingEpilogueParams() {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : Base(alpha, beta, tensor_C, tensor_D) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {} +}; + + + + + +/////////////////////////////////////////////////////////// +// +// Generic Gett 3x Implementation +// +/////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +void compute_1d_scaling_factor_and_quantized_output( + EpilogueParams const& epilogue_params, + TensorD &tensor_D, + TensorSFD &tensor_SfD, + int64_t m, + int64_t n, + int64_t l, + ElementCompute (&acc)[kBlockM][kBlockN]) +{ + using ElementD = typename ElementTraits::type; + using ElementSfD = typename ElementTraits::type; + + int const M = cute::size<0>(tensor_D.layout()); + int const N 
= cute::size<1>(tensor_D.layout()); + int const L = cute::size<2>(tensor_D.layout()); + + auto mul = cutlass::multiplies{}; + auto div = divides{}; + // Get FP max + ElementCompute fp_max = ElementCompute(std::numeric_limits::max()); + float scale_down_factor = div(1.0f, fp_max); + // Get st' = st / FP max + ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor); + + absolute_value_op abs_op; + maximum_with_nan_propogation max_op; + + if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) { + // MN major output + int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); + // Col major output + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t col = n + n_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_row = m + kVectorSize * v_b; + if (sf_row < M && col < N) { + tensor_SfD(sf_row, col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale); + } + } + } + } + } + else { + int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize); + // row major output + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t row = m + m_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_col = n + kVectorSize * v_b; + + if (row < M && sf_col < N) { + tensor_SfD(row, sf_col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale); + } + } + } + } + } +} + + 
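// Illustrative sketch (not from the CUTLASS sources): a plain-float model of the per-vector
// scale-factor generation implemented above. The names kVectorSize, st, and fp_max mirror the
// quantities used in compute_1d_scaling_factor_and_quantized_output; the numeric values are made
// up, and the real routine additionally quantizes the scale factor to ElementSfD and clamps the
// rescale factor with minimum_with_nan_propagation before casting the results to ElementD.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  constexpr int kVectorSize = 4;
  float st = 1.0f;      // epilogue scalar 'st' from GettEpilogueParams
  float fp_max = 6.0f;  // largest finite value of the narrow output type (e.g. a 4-bit float format)
  float acc[kVectorSize] = {0.5f, -3.0f, 1.25f, 2.0f};  // one vector of accumulator values

  // Step 1: absolute maximum over the vector.
  float amax = 0.0f;
  for (float v : acc) { amax = std::max(amax, std::fabs(v)); }

  // Step 2: per-vector scale factor, amax * (st / fp_max).
  float sf = amax * (st / fp_max);

  // Step 3: rescale the accumulators by st / sf so the vector's amax lands on fp_max
  // before the final cast to the narrow output type.
  for (float &v : acc) { v *= st / sf; }

  std::printf("amax = %g, sf = %g, rescaled |acc[1]| = %g\n", amax, sf, std::fabs(acc[1]));
  return 0;
}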
///////////////////////////////////////////////////////////////////////////////////////////////// /// GETT - General Tensor-Tensor contraction reference kernel @@ -188,6 +530,11 @@ void gett_mainloop( using ElementA = typename ElementTraits::type; using ElementB = typename ElementTraits::type; + + using ElementSFA = typename ElementTraits::type; + using ElementSFB = typename ElementTraits::type; + + using RingOp = multiply_add; RingOp fma_op; @@ -207,6 +554,14 @@ void gett_mainloop( // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + + if constexpr (not cute::is_same_v){ + // Load SFA + auto sfa = static_cast(mainloop_params.SfA(m + m_b, k, l)); + a_frag[m_b] *= sfa; + } + + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { a_frag[m_b] = conj(a_frag[m_b]); } @@ -222,6 +577,14 @@ void gett_mainloop( // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + + if constexpr (not cute::is_same_v){ + // Load SFB + auto sfb = static_cast(mainloop_params.SfB(n + n_b, k, l)); + b_frag[n_b] *= sfb; + } + + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { b_frag[n_b] = conj(b_frag[n_b]); } @@ -259,6 +622,7 @@ void gett_epilogue( using ElementCompute = typename EpilogueParams::ElementCompute; using ElementC = typename EpilogueParams::TensorC::value_type; using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementSfD = typename EpilogueParams::TensorSFD::value_type; using ElementAux = typename EpilogueParams::TensorAux::value_type; using ElementBias = typename EpilogueParams::VectorBias::value_type; using ElementScalar = typename EpilogueParams::ElementScalar; @@ -267,6 +631,8 @@ void gett_epilogue( using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy; + constexpr bool IsScalingAndAmaxOutputNeeded = cute::is_same_v or cute::is_same_v; @@ -412,6 +778,17 @@ void gett_epilogue( } } } // m_b + + if constexpr ( + SfGenStrategy == SfStrategy::SfDGen + ) { + // 1d scale factor generation + constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{}; + if (epilogue_params.SfD.data() != nullptr) { + compute_1d_scaling_factor_and_quantized_output(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum); + } + } + for (int m_b = 0; m_b < kBlockM; ++m_b) { for (int n_b = 0; n_b < kBlockN; ++n_b) { if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {