Skip to content

Commit

Permalink
Support parallel split K mode for profiling
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Han <[email protected]>
  • Loading branch information
Peter9606 committed Jun 11, 2021
1 parent 6a10640 commit 3584270
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 5 deletions.
4 changes: 3 additions & 1 deletion tools/library/src/gemm_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ class GemmUniversalOperation : public GemmOperationBase<Operator_> {
operator_args.ldb = int(configuration->ldb);
operator_args.ldc = int(configuration->ldc);
operator_args.ldd = int(configuration->ldd);


operator_args.batch_stride_D = int(configuration->problem_size.m()) * int(configuration->problem_size.n());

return Status::kSuccess;
}

Expand Down
6 changes: 3 additions & 3 deletions tools/library/src/reduction/reduction_device.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
1,
ElementAccumulator,
ElementCompute
>;
Expand Down Expand Up @@ -81,7 +81,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) {

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
1,
ElementAccumulator,
ElementCompute
>;
Expand Down Expand Up @@ -115,7 +115,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
1,
ElementAccumulator,
ElementCompute
>;
Expand Down
169 changes: 168 additions & 1 deletion tools/profiler/src/gemm_operation_profiler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "gemm_operation_profiler.h"
#include "gpu_timer.h"

#include "cutlass/library/singleton.h"
#include "cutlass/library/library.h"
#include "cutlass/library/handle.h"

Expand All @@ -55,6 +56,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
library::OperationKind::kGemm,
{
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
Expand Down Expand Up @@ -91,6 +93,9 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
<< "Profile a particular problem size:\n"
<< " $ cutlass_profiler --operation=Gemm --m=1024 --n=1024 --k=128\n\n"

<< "Profile a particular problem size with split K and paralell reduction:\n"
<< " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --m=1024 --n=1024 --k=128\n\n"

<< "Schmoo over problem size and beta:\n"
<< " $ cutlass_profiler --operation=Gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n"

Expand Down Expand Up @@ -155,7 +160,12 @@ Status GemmOperationProfiler::GemmProblem::parse(
// default value
this->k = 1024;
}


if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
// default value
this->split_k_mode = library::SplitKMode::kSerial;
}

this->mode = library::GemmUniversalMode::kGemm;
if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) {
// default value
Expand All @@ -170,6 +180,10 @@ Status GemmOperationProfiler::GemmProblem::parse(
this->mode = library::GemmUniversalMode::kBatched;
}

if(this->split_k_mode == library::SplitKMode::kParallel) {
this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
}

if (this->split_k_slices > 1 && this->batch_count > 1) {
// At least one of these must be one
return Status::kErrorInvalidProblem;
Expand Down Expand Up @@ -275,6 +289,8 @@ void GemmOperationProfiler::GemmProblem::initialize_result(

set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));

set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));

set_argument(result, "A", problem_space,
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));

Expand Down Expand Up @@ -346,6 +362,13 @@ Status GemmOperationProfiler::initialize_configuration(
gemm_workspace_.arguments.beta = problem_.beta.data();
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;

// initialize reduction operation for parallel splitKMode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
if (!initialize_reduction_configuration_(operation, problem)) {
return Status::kErrorInternal;
}
}

initialize_result_(this->model_result_, options, operation_desc, problem_space);

return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
Expand Down Expand Up @@ -373,6 +396,52 @@ void GemmOperationProfiler::initialize_result_(

}

/// Prepares the reduction stage used by parallel split-K profiling.
///
/// Three things happen, in order:
///   1. problem_.alpha_one and problem_.beta_zero are filled by cast_from_double() with the
///      constants 1 and 0, using the GEMM's epilogue element type.
///   2. gemm_workspace_.reduction_configuration is populated from the current problem_
///      (M x N extent, number of split-K slices, per-partition stride, leading dimensions).
///   3. A reduction operation matching the GEMM's accumulator / output / epilogue element
///      types is looked up in the library singleton's operation table and stored in
///      reduction_op_.
///
/// \param operation  GEMM operation whose description supplies the element types.
/// \param problem    Not read by this function.
/// \return true when reduction_op_ has been assigned; false if either scalar cast fails or
///         no matching reduction operation is registered. Note that on a failed lookup the
///         reduction configuration has already been written.
bool GemmOperationProfiler::initialize_reduction_configuration_(
  library::Operation const *operation,
  ProblemSpace::Problem const &problem) {

  auto const &gemm_desc =
    static_cast<library::GemmDescription const &>(operation->description());

  // Short-circuit evaluation keeps the original ordering: beta_zero is only attempted
  // once alpha_one has been cast successfully.
  if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1) ||
      !cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) {
    return false;
  }

  // M x N extent of the output; its element count is also the stride between partitions.
  auto const extent_mn =
    gemm::GemmCoord(int(problem_.m), int(problem_.n), int(problem_.k)).mn();

  auto &reduction_config = gemm_workspace_.reduction_configuration;

  reduction_config.problem_size = extent_mn;
  reduction_config.partitions = int(problem_.split_k_slices);
  reduction_config.partition_stride = extent_mn.product();

  // Workspace, source and destination all use the leading dimension of C.
  reduction_config.ldw = problem_.ldc;
  reduction_config.lds = problem_.ldc;
  reduction_config.ldd = problem_.ldc;

  // Functional key identifying the reduction kernel that pairs with this GEMM.
  auto const &math_instruction = gemm_desc.tile_description.math_instruction;

  library::ReductionFunctionalKey reduction_key(
    library::Provider::kCUTLASS,
    math_instruction.element_accumulator,   // element workspace
    math_instruction.element_accumulator,   // element accumulator
    gemm_desc.C.element,                    // element output
    gemm_desc.element_epilogue              // element compute
  );

  auto const &reduction_table =
    library::Singleton::get().operation_table.reduction_operations;

  auto const match = reduction_table.find(reduction_key);

  if (match == reduction_table.end()) {
    // No reduction kernel registered for this combination of element types.
    return false;
  }

  // Remember the reduction operation that will run after the split-K GEMM.
  reduction_op_ = match->second;

  return true;
}

/// Initializes workspace
Status GemmOperationProfiler::initialize_workspace(
Options const &options,
Expand Down Expand Up @@ -473,6 +542,24 @@ Status GemmOperationProfiler::initialize_workspace(
&gemm_workspace_.configuration,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());

if (status != Status::kSuccess) {
return status;
}

if (problem_.split_k_mode == library::SplitKMode::kParallel) {
workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration);
gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0);

status = reduction_op_->initialize(
&gemm_workspace_.reduction_configuration,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);

if (status != Status::kSuccess) {
return status;
}
}
}

//
Expand Down Expand Up @@ -523,6 +610,19 @@ bool GemmOperationProfiler::verify_cutlass(
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();

if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
gemm_workspace_.arguments.beta = problem_.beta_zero.data();

gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
}

//
// Run the CUTLASS operation
//
Expand All @@ -537,6 +637,19 @@ bool GemmOperationProfiler::verify_cutlass(
return false;
}

// Run parallel reduction kernel for parallel split_k_mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
results_.back().status = reduction_op_->run(
&gemm_workspace_.reduction_arguments,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);

if (results_.back().status != Status::kSuccess) {
results_.back().disposition = Disposition::kFailed;
return false;
}
}

cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
results_.back().disposition = Disposition::kFailed;
Expand Down Expand Up @@ -892,6 +1005,19 @@ bool GemmOperationProfiler::profile(
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();

if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
gemm_workspace_.arguments.beta = problem_.beta_zero.data();

gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
}

results_.back().status = profile_cutlass_(
results_.back().runtime,
options,
Expand Down Expand Up @@ -938,6 +1064,14 @@ Status GemmOperationProfiler::profile_cutlass_(
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);

if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();

gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
}

// Execute the CUTLASS operation
status = operation->run(
&gemm_workspace_.arguments,
Expand All @@ -947,6 +1081,18 @@ Status GemmOperationProfiler::profile_cutlass_(
if (status != Status::kSuccess) {
return status;
}

// Run parallel reduction kernel for parallel split_k_mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
status = reduction_op_->run(
&gemm_workspace_.reduction_arguments,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);

if (status != Status::kSuccess) {
return status;
}
}
}

//
Expand All @@ -973,6 +1119,14 @@ Status GemmOperationProfiler::profile_cutlass_(
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);

if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();

gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
}

status = operation->run(
arguments,
host_workspace,
Expand All @@ -981,6 +1135,19 @@ Status GemmOperationProfiler::profile_cutlass_(
if (status != Status::kSuccess) {
return status;
}

// Run parallel reduction kernel for parallel split_k_mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
status = reduction_op_->run(
&gemm_workspace_.reduction_arguments,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);

if (status != Status::kSuccess) {
return status;
}
}

}

//
Expand Down
21 changes: 21 additions & 0 deletions tools/profiler/src/gemm_operation_profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "operation_profiler.h"
#include "performance_result.h"
#include "problem_space.h"
#include "reduction_operation_profiler.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

Expand All @@ -61,6 +62,7 @@ class GemmOperationProfiler : public OperationProfiler {
struct GemmProblem {

cutlass::library::GemmUniversalMode mode;
cutlass::library::SplitKMode split_k_mode;
int64_t m;
int64_t n;
int64_t k;
Expand All @@ -72,6 +74,12 @@ class GemmOperationProfiler : public OperationProfiler {
int split_k_slices;
int batch_count;

// gemm with parallel interleaved reduction
// gemm epilogue (alpha, beta) = (1.0, 0.0)
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
std::vector<uint8_t> alpha_one;
std::vector<uint8_t> beta_zero;

//
// Methods
//
Expand Down Expand Up @@ -121,6 +129,13 @@ class GemmOperationProfiler : public OperationProfiler {
/// Buffer used for the operations' device workspace
DeviceAllocation device_workspace;

/// Library configuration and arguments for reduction operator
library::ReductionConfiguration reduction_configuration;
library::ReductionArguments reduction_arguments;

/// Buffer used for the cutlass reduction operations' host workspace
std::vector<uint8_t> reduction_host_workspace;

//
// Methods
//
Expand All @@ -141,6 +156,8 @@ class GemmOperationProfiler : public OperationProfiler {
/// Device memory allocations
GemmWorkspace gemm_workspace_;

/// CUTLASS parallel reduction operation to follow this* gemm operation
library::Operation const *reduction_op_;

public:
//
Expand Down Expand Up @@ -231,6 +248,10 @@ class GemmOperationProfiler : public OperationProfiler {
void *host_workspace,
void *device_workspace);

/// Initialize reduction problem dimensions and library::Operation
bool initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem);
};

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 3584270

Please sign in to comment.