Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add epilogue to store MMA results in shared memory before write to #387

Merged
merged 36 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
379eab5
add epilogue to store MMA results in shared memory before write to
liqiangxl May 22, 2023
b500d36
revise test
liqiangxl Jun 8, 2023
acf1167
format
liqiangxl Jun 8, 2023
be885b0
swizzleSharedMemory
liqiangxl Jun 20, 2023
7138cbe
format
liqiangxl Jun 20, 2023
fad4ad8
fix failed test cases
liqiangxl Jun 21, 2023
84e3e98
propagate to epilogue tensors
liqiangxl Jun 28, 2023
a94e5df
check num_shared_mem_tensors
liqiangxl Jun 28, 2023
9f4bcc4
format
liqiangxl Jun 28, 2023
1760c15
disable_smem_epilogue
liqiangxl Jun 28, 2023
f0ff6f9
extend MatmulSASSTest
liqiangxl Jun 28, 2023
da5dc3a
schedule output tensor
liqiangxl Jun 29, 2023
537b855
wip
liqiangxl Jun 29, 2023
ab86a1f
use propagate
liqiangxl Jul 1, 2023
f2a75cd
fix failed case
liqiangxl Jul 3, 2023
7a4d5b5
fix ci fails by increasing tolerance
liqiangxl Jul 4, 2023
86b8911
merge main
liqiangxl Jul 7, 2023
5586b3a
fix failed cases
liqiangxl Jul 7, 2023
6a8f139
trivial fix
liqiangxl Jul 7, 2023
d6212cb
format
liqiangxl Jul 9, 2023
95ea553
revise hasEnoughSharedMemoryForEpilogue
liqiangxl Jul 9, 2023
925e04d
merge main
liqiangxl Jul 9, 2023
1f30a36
wip
liqiangxl Jul 10, 2023
80d7588
cacheAfter mma_result
liqiangxl Jul 10, 2023
d3019f0
add epilogue cast and relu tests
liqiangxl Jul 10, 2023
212258c
trivial fix
liqiangxl Jul 10, 2023
a2045cd
mma data types
liqiangxl Jul 12, 2023
32f43d8
merge main
liqiangxl Jul 12, 2023
67ecdb0
revise smem swizzle
liqiangxl Jul 13, 2023
864a918
test with revised swizzle
liqiangxl Jul 13, 2023
79c452c
merge main
liqiangxl Jul 14, 2023
26d970d
save file
liqiangxl Jul 14, 2023
189cef4
revise based on review comments
liqiangxl Jul 17, 2023
889ce23
rename
liqiangxl Jul 17, 2023
94aae9b
change default to false
liqiangxl Jul 17, 2023
27870d0
Merge branch 'main' into llu/matmul_epilogue
liqiangxl Jul 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 235 additions & 74 deletions csrc/scheduler/matmul.cpp

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions csrc/scheduler/matmul_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class MatmulParams : public HeuristicParams {
//! C3 C4 D3 D4
int grid_swizzle_factor = 1;

//! Unswizzle MMA results in shared memory to get
//! coalesced write to global memory
bool use_smem_epilogue = false;

std::string toString() const override {
std::stringstream ss;
ss << "\n===== Matmul Parameters ========\n"
Expand All @@ -112,6 +116,7 @@ class MatmulParams : public HeuristicParams {
: "column-major")
<< "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n"
<< "Use shared memory epilogue: " << use_smem_epilogue << "\n"
<< "====================================\n";
return ss.str();
}
Expand Down
26 changes: 26 additions & 0 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@ inline bool initExtraHeuristics(
return true;
}

//! A wrapper to get MMA Tensor data types
drzejan2 marked this conversation as resolved.
Show resolved Hide resolved
//! The order of returned types: INPUT_A, INPUT_B, OUTPUT_D
inline mma_utils::MmaDataTypes getMmaDataTypes(
const std::map<MatmulRole, std::vector<TensorView*>>& roles_map) {
auto getMMADataType = [&](MatmulRole role) {
auto entry = roles_map.find(role);
if (entry != roles_map.end() && !entry->second.empty()) {
return entry->second.front()->dtype();
}
TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!");
};
const auto a_type = getMMADataType(MatmulRole::INPUT_A);
const auto b_type = getMMADataType(MatmulRole::INPUT_B);
const auto c_type = getMMADataType(MatmulRole::OUTPUT_D);
return mma_utils::MmaDataTypes{a_type, b_type, c_type};
}

//! A helper for getting problem shape from fusion and runtime info.
ProblemShape getProblemShape(
Fusion* fusion,
Expand Down Expand Up @@ -398,6 +415,15 @@ std::shared_ptr<MatmulParams> getMatmulHeuristics(
// Disable magic zero for matmul kernels
params->cparams.enable_magic_zero = false;

// Set whether to use shared memory for epilogue
const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion);
TORCH_INTERNAL_ASSERT(
roles_map_opt.isValid(), "Tensor roles map in mma is not valid.");
params->use_smem_epilogue = mma_utils::generateSharedMemoryEpilogueHeuristics(
params->tile_sizes,
params->double_buffer_options.smem_double_buffer_stage,
getMmaDataTypes(roles_map_opt.getData()));

if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) {
printMsg(params->toString());
}
Expand Down
53 changes: 43 additions & 10 deletions csrc/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
// clang-format on

#include <ATen/cuda/CUDAContext.h>
#include <device_lower/utils.h>
#include <expr_evaluator.h>
#include <ir/printer.h>
Expand All @@ -14,11 +15,49 @@
#include <scheduler/utils.h>
#include <variant>
#include "mma_type.h"

namespace nvfuser {

namespace mma_utils {

//! Decide whether the matmul epilogue should use an additional shared memory
//! buffer for the MMA results.
//!
//! The decision is based on an estimate of how many blocks fit on an SM:
//!  - by registers, assuming each thread uses the maximum of 255 registers;
//!  - by shared memory, comparing the footprint of the operand buffers alone
//!    (A and B, each multiplied by the number of buffering stages) with the
//!    footprint of the operand buffers plus a cta_tile.m x cta_tile.n buffer
//!    of the output data type.
//!
//! gemm_tile: provides the CTA and warp tile sizes
//! smem_double_buffer_stage: number of stages of the A/B shared memory buffers
//! data_types: data types ordered as INPUT_A, INPUT_B, OUTPUT_D
//!
//! Returns true only if the epilogue buffer fits within the device's
//! sharedMemPerBlockOptin limit and the estimated number of blocks per SM is
//! the same with and without it.
bool generateSharedMemoryEpilogueHeuristics(
    const MatMulTileOptions& gemm_tile,
    const int smem_double_buffer_stage,
    const MmaDataTypes& data_types) {
  const auto properties = at::cuda::getCurrentDeviceProperties();
  const size_t device_smem_limit = properties->sharedMemPerBlockOptin;

  const auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile;
  const auto threads_per_block =
      warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize;
  // a thread can use up to 255 registers, blocks per sm is limited by available
  // registers
  const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255);
  const auto blocks_per_sm_by_register =
      static_cast<size_t>(threads_per_sm / threads_per_block);

  // see scheduleContiguousVectorLoad
  const int vector_word = 8;
  // Operand tiles are rounded up to a multiple of
  // (threads per block * vector word) elements.
  const int round_to_factor = threads_per_block * vector_word;
  const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k;
  const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k;
  const size_t smem_a =
      static_cast<size_t>(
          ceilDiv(mk, round_to_factor) * round_to_factor *
          smem_double_buffer_stage) *
      dataTypeSize(data_types[0]);
  const size_t smem_b =
      static_cast<size_t>(
          ceilDiv(nk, round_to_factor) * round_to_factor *
          smem_double_buffer_stage) *
      dataTypeSize(data_types[1]);
  const size_t smem_c =
      static_cast<size_t>(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) *
      dataTypeSize(data_types[2]);

  const size_t smem_without_epilogue = smem_a + smem_b;
  const size_t smem_with_epilogue = smem_without_epilogue + smem_c;

  // Never request an epilogue buffer that does not fit. Without this check, a
  // configuration whose estimated block count is already zero (operands alone
  // exceed the limit, or the register estimate allows no block) would compare
  // 0 == 0 below and enable an epilogue that cannot be allocated.
  if (smem_with_epilogue > device_smem_limit) {
    return false;
  }

  // use additional shared memory for epilogue if blocks per sm is not changed
  const auto blocks_per_sm_without_smem_epilogue = std::min(
      device_smem_limit / smem_without_epilogue, blocks_per_sm_by_register);
  const auto blocks_per_sm_with_smem_epilogue = std::min(
      device_smem_limit / smem_with_epilogue, blocks_per_sm_by_register);
  return blocks_per_sm_with_smem_epilogue ==
      blocks_per_sm_without_smem_epilogue;
}

void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
// Assumes
// [M, N, K]
Expand Down Expand Up @@ -379,11 +418,6 @@ bool canValidateIsInnerDim(
if (!split->factor()->isConstInt()) {
return false;
}
if (split->factor()->evaluateInt() < inner_dim_size) {
drzejan2 marked this conversation as resolved.
Show resolved Hide resolved
// This might be too restrictive. Would need more
// bookkeeping to relax.
return false;
}
zasdfgbnm marked this conversation as resolved.
Show resolved Hide resolved
leaf = split->in();
} else if (auto merge = dynamic_cast<Merge*>(expr)) {
// Might consider just rejecting merge.
Expand All @@ -396,9 +430,6 @@ bool canValidateIsInnerDim(
if (!leaf->extent()->isConstInt()) {
return false;
}
if (leaf->extent()->evaluateInt() != inner_dim_size) {
return false;
}
leaf = merge->inner();
} else {
// No support for swizzled inner dim for now.
Expand Down Expand Up @@ -438,7 +469,9 @@ void checkDimSize(
":",
id->extent()->evaluateInt(),
"vs",
expect[axis_index]);
expect[axis_index],
"\n for tv: ",
tv->toString());
}
}

Expand Down
13 changes: 13 additions & 0 deletions csrc/scheduler/mma_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ using ProblemIterDomains = std::array<IterDomain*, 3>;
//! a single tv, for example input for beta scaling in epilogue
using RolesMap = std::map<MatmulRole, std::vector<TensorView*>>;

//! An alias for storing the data types of the tensors in the mma op.
//! The entries are ordered as INPUT_A, INPUT_B, OUTPUT_D.
using MmaDataTypes = std::array<DataType, 3>;

//! A wrapper for data containers with optional error message stored if
//! initialization of the data fails.
template <typename DataType>
Expand Down Expand Up @@ -289,6 +293,15 @@ TORCH_CUDA_CU_API ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion);
//! be gathered.
TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion);

//! Returns whether the shared memory epilogue should be used.
//! The result is true if using the shared memory epilogue is not expected
//! to decrease the occupancy ratio. The occupancy ratio is estimated from
//! register and shared memory usage.
//!  - gemm_tile: tile options providing the CTA tile and warp tile sizes
//!  - smem_double_buffer_stage: number of stages used for the shared memory
//!    buffers of the A and B operands
//!  - data_types: data types ordered as INPUT_A, INPUT_B, OUTPUT_D
TORCH_CUDA_CU_API bool generateSharedMemoryEpilogueHeuristics(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage,
const MmaDataTypes& data_types);

} // namespace mma_utils

} // namespace nvfuser
Loading