remove overlap in device architectures for cutlass_scaled_mm
LucasWilkinson committed Oct 2, 2024
1 parent 475e57a commit f9f1dc3
Showing 2 changed files with 79 additions and 33 deletions.
42 changes: 31 additions & 11 deletions CMakeLists.txt
@@ -239,8 +239,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -268,26 +267,47 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
 
   #
-  # The CUTLASS kernels for Hopper require sm90a to be enabled.
-  # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
-  # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
-  # Only build scaled_mm_c3x if we are building for something compatible with sm90a
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
+  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
+  # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
+  cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
-    message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
   else()
+    # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
+    # build any 3x kernels
+    set(SCALED_MM_3X_ARCHS)
     message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
                    "in CUDA target architectures or CUDA Compiler version is "
                    "not >= 12.0")
   endif()
 
+  #
+  # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
+  # kernels for the remaining archs that are not already built for 3x.
+  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
+    "7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}")
+  # subtract out the archs that are already built for 3x
+  list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
+  if (SCALED_MM_2X_ARCHS)
+    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
+    message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
+  else()
+    message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
+                   "in CUDA target architectures (or archs are already built "
+                   "for 3x)")
+  endif()
 
   #
   # Machete kernels
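To make the resulting split concrete, here is a minimal standalone sketch (not part of the commit, runnable with `cmake -P`) that mimics the intersect-then-subtract logic above using only built-in list() operations. The CUDA_ARCHS value, the file name, and the naive exact-match intersection are illustrative assumptions; the real build uses vLLM's cuda_archs_loose_intersection helper, which presumably also matches suffixed archs such as 9.0a more loosely.

# arch_partition.cmake (hypothetical); run with: cmake -P arch_partition.cmake
cmake_minimum_required(VERSION 3.20)

set(CUDA_ARCHS "7.5;8.0;8.9;9.0")  # assumed targets: Turing, Ampere, Ada, Hopper

# Candidate archs for the c3x (CUTLASS 3.x) and c2x (CUTLASS 2.x) kernels,
# mirroring the lists in the CMakeLists.txt diff above.
set(_3X_CANDIDATES "9.0;9.0a")
set(_2X_CANDIDATES "7.5;8.0;8.6;8.9;9.0;9.0a")

# Naive exact-match intersection standing in for cuda_archs_loose_intersection.
function(intersect_archs OUT_VAR CANDIDATES)
  set(_result "")
  foreach(_arch IN LISTS CUDA_ARCHS)
    if(_arch IN_LIST CANDIDATES)
      list(APPEND _result "${_arch}")
    endif()
  endforeach()
  set(${OUT_VAR} "${_result}" PARENT_SCOPE)
endfunction()

intersect_archs(SCALED_MM_3X_ARCHS "${_3X_CANDIDATES}")
intersect_archs(SCALED_MM_2X_ARCHS "${_2X_CANDIDATES}")

# The key step of this commit: drop the archs already covered by the 3x build
# so the c2x kernels are not compiled for them a second time.
if(SCALED_MM_3X_ARCHS)
  list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
endif()

message(STATUS "scaled_mm_c3x archs: ${SCALED_MM_3X_ARCHS}")  # 9.0
message(STATUS "scaled_mm_c2x archs: ${SCALED_MM_2X_ARCHS}")  # 7.5;8.0;8.9

With these inputs, scaled_mm_c3x would be built only for 9.0 and scaled_mm_c2x would cover the remaining 7.5, 8.0 and 8.9, so no architecture gets both kernel families, which is exactly the overlap this commit removes.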
70 changes: 48 additions & 22 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,

   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
   int32_t version_num = get_sm_version_num();
-  if (version_num >= 90) {
-    // Hopper
+  // Hopper
 
-    // Guard against compilation issues for sm90 kernels
+  // Guard against compilation issues for sm90 kernels
 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+  if (version_num >= 90) {
     cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
-#else
-    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
+    return;
+  }
 #endif
-  } else if (version_num == 89) {
+
+#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
+  if (version_num == 89) {
     // Ada Lovelace
     cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
-  } else if (version_num >= 80) {
+    return;
+  }
+
+  if (version_num >= 80) {
     // Ampere
     cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
-  } else {
-    // Turing
-    TORCH_CHECK(version_num >= 75);
-    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
+    return;
   }
+
+  // Turing
+  TORCH_CHECK(version_num >= 75);
+  cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
+#endif
+
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled cutlass_scaled_mm for a compute capability less than "
+      "CUDA device capability: ",
+      version_num);
 }
 
 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype ", c.dtype());

at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
} else if (version_num == 89) {

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else if (version_num >= 80) {
return;
}

if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}

// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif

TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}
