add epilogue to store MMA results in shared memory before write to #387

Merged Jul 17, 2023 (36 commits; branch llu/matmul_epilogue into main).
Commits (36, all by liqiangxl). The diff below reflects changes from 29 of these commits:
379eab5  add epilogue to store MMA results in shared memory before write to (May 22, 2023)
b500d36  revise test (Jun 8, 2023)
acf1167  format (Jun 8, 2023)
be885b0  swizzleSharedMemory (Jun 20, 2023)
7138cbe  format (Jun 20, 2023)
fad4ad8  fix failed test cases (Jun 21, 2023)
84e3e98  propagate to epilogue tensors (Jun 28, 2023)
a94e5df  check num_shared_mem_tensors (Jun 28, 2023)
9f4bcc4  format (Jun 28, 2023)
1760c15  disable_smem_epilogue (Jun 28, 2023)
f0ff6f9  extend MatmulSASSTest (Jun 28, 2023)
da5dc3a  schedule output tensor (Jun 29, 2023)
537b855  wip (Jun 29, 2023)
ab86a1f  use propagate (Jul 1, 2023)
f2a75cd  fix failed case (Jul 3, 2023)
7a4d5b5  fix ci fails by increasing tolerance:x (Jul 4, 2023)
86b8911  merge main (Jul 7, 2023)
5586b3a  fix failed cases (Jul 7, 2023)
6a8f139  trivial fix (Jul 7, 2023)
d6212cb  format (Jul 9, 2023)
95ea553  revise hasEnoughSharedMemoryForEpilogue (Jul 9, 2023)
925e04d  merge main (Jul 9, 2023)
1f30a36  wip (Jul 10, 2023)
80d7588  cacheAfter mma_result (Jul 10, 2023)
d3019f0  add epilogue cast and relu tests (Jul 10, 2023)
212258c  trivial fix (Jul 10, 2023)
a2045cd  mma data types (Jul 12, 2023)
32f43d8  merge main (Jul 12, 2023)
67ecdb0  revise smem swizzle (Jul 13, 2023)
864a918  test with revised swizzle (Jul 13, 2023)
79c452c  merge main (Jul 14, 2023)
26d970d  save file (Jul 14, 2023)
189cef4  revise based on review comments (Jul 17, 2023)
889ce23  rename (Jul 17, 2023)
94aae9b  change default to false (Jul 17, 2023)
27870d0  Merge branch 'main' into llu/matmul_epilogue (Jul 17, 2023)
357 changes: 266 additions & 91 deletions csrc/scheduler/matmul.cpp

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions csrc/scheduler/matmul_heuristic.h
@@ -90,6 +90,10 @@ class MatmulParams : public HeuristicParams {
//! C3 C4 D3 D4
int grid_swizzle_factor = 1;

//! Unswizzle MMA results in shared memory to get
//! coalesced write to global memory
bool has_smem_epilogue = true;

std::string toString() const override {
std::stringstream ss;
ss << "\n===== Matmul Parameters ========\n"
@@ -112,6 +116,7 @@
: "column-major")
<< "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n"
<< "Use shared memory epilogue: " << has_smem_epilogue << "\n"
<< "====================================\n";
return ss.str();
}
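A minimal sketch of how the new flag is meant to be used alongside the existing parameters; the field names and the printed label come from the diff above, while the include path and surrounding function are illustrative assumptions:

#include <scheduler/matmul_heuristic.h>
#include <iostream>

void dumpMatmulParams() {
  nvfuser::MatmulParams params;     // all other fields keep their defaults
  params.grid_swizzle_factor = 1;
  params.has_smem_epilogue = true;  // route MMA results through shared memory
  // The debug dump now also reports the flag, printed as 1/0:
  // "Use shared memory epilogue: 1"
  std::cout << params.toString();
}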
34 changes: 34 additions & 0 deletions csrc/scheduler/matmul_utils.cpp
@@ -150,6 +150,23 @@ inline bool initExtraHeuristics(
return true;
}

//! A wrapper to get MMA Tensor data types
drzejan2 marked this conversation as resolved.
//! The order of returned types: INPUT_A, INPUT_B, OUTPUT_D
inline mma_utils::MmaDataTypes getMmaDataTypes(
const std::map<MatmulRole, std::vector<TensorView*>>& roles_map) {
auto getMMADataType = [&](MatmulRole role) {
auto entry = roles_map.find(role);
if (entry != roles_map.end() && !entry->second.empty()) {
return entry->second.front()->dtype();
}
TORCH_INTERNAL_ASSERT(false, "Get MMA Tensor data type failed!");
};
const auto a_type = getMMADataType(MatmulRole::INPUT_A);
const auto b_type = getMMADataType(MatmulRole::INPUT_B);
const auto c_type = getMMADataType(MatmulRole::OUTPUT_D);
return mma_utils::MmaDataTypes{a_type, b_type, c_type};
}

//! A helper for getting problem shape from fusion and runtime info.
ProblemShape getProblemShape(
Fusion* fusion,
@@ -398,6 +415,23 @@ std::shared_ptr<MatmulParams> getMatmulHeuristics(
// Disable magic zero for matmul kernels
params->cparams.enable_magic_zero = false;

// Disable shared memory epilogue before shared memory reuse is implemented.
// Otherwise, there will be performance regression due to reduced occupancy
// caused by extra shared memory usage.
constexpr bool allow_smem_epilogue = true;
Collaborator:
Is this still true with the new hasEnoughSharedMemoryForEpilogue that takes register occupancy into consideration?

Collaborator (Author):
In the case of FusionAmpereMatmulSmemEpilogue_CUDA, the settings of cta_tile (64, 128, 32), warp_tile (32, 32, 32), and smem_double_buffer_stage (2) are purposefully chosen to produce a constant occupancy of 25% whether has_smem_epilogue is true or false. This lets me evaluate the influence of has_smem_epilogue in isolation. There is no noticeable change in performance for TT, NT, TN (kernels 1-3), but for NN (kernel 4) there is a slight decrease. I checked the global memory write index: the access is coalesced for has_smem_epilogue = true and strided for has_smem_epilogue = false.

has_smem_epilogue = false
smem_false_64_128_32.log:kernel1 run in 1.19603 ms, achieved: 112.219 GB/s
smem_false_64_128_32.log:kernel2 run in 1.18682 ms, achieved: 113.091 GB/s
smem_false_64_128_32.log:kernel3 run in 1.00045 ms, achieved: 134.158 GB/s
smem_false_64_128_32.log:kernel4 run in 1.16224 ms, achieved: 115.482 GB/s

has_smem_epilogue = true
smem_true_64_128_32.log:kernel1 run in 1.19296 ms, achieved: 112.508 GB/s
smem_true_64_128_32.log:kernel2 run in 1.19091 ms, achieved: 112.702 GB/s
smem_true_64_128_32.log:kernel3 run in 0.995328 ms, achieved: 134.848 GB/s
smem_true_64_128_32.log:kernel4 run in 1.18067 ms, achieved: 113.679 GB/s

Global Memory write index for MatmulLayout::NN with has_smem_epilogue = false:

smem_false4_64_128_32_write_index
tidyz=00 tidx=0 from=0 to=0
tidyz=00 tidx=1 from=0 to=2
tidyz=00 tidx=2 from=0 to=4
tidyz=00 tidx=3 from=0 to=6
tidyz=00 tidx=4 from=0 to=4096
tidyz=00 tidx=5 from=0 to=4098
tidyz=00 tidx=6 from=0 to=4100
tidyz=00 tidx=7 from=0 to=4102
tidyz=00 tidx=8 from=0 to=8192
tidyz=00 tidx=9 from=0 to=8194
tidyz=00 tidx=10 from=0 to=8196
tidyz=00 tidx=11 from=0 to=8198
tidyz=00 tidx=12 from=0 to=12288
tidyz=00 tidx=13 from=0 to=12290
tidyz=00 tidx=14 from=0 to=12292
tidyz=00 tidx=15 from=0 to=12294
tidyz=00 tidx=16 from=0 to=16384
tidyz=00 tidx=17 from=0 to=16386
tidyz=00 tidx=18 from=0 to=16388
tidyz=00 tidx=19 from=0 to=16390
tidyz=00 tidx=20 from=0 to=20480
tidyz=00 tidx=21 from=0 to=20482
tidyz=00 tidx=22 from=0 to=20484
tidyz=00 tidx=23 from=0 to=20486
tidyz=00 tidx=24 from=0 to=24576
tidyz=00 tidx=25 from=0 to=24578
tidyz=00 tidx=26 from=0 to=24580
tidyz=00 tidx=27 from=0 to=24582
tidyz=00 tidx=28 from=0 to=28672
tidyz=00 tidx=29 from=0 to=28674
tidyz=00 tidx=30 from=0 to=28676
tidyz=00 tidx=31 from=0 to=28678

Global Memory write index for MatmulLayout::NN with has_smem_epilogue = true:

smem_true4_64_128_32_write_index
tidyz=00 tidx=0 from=0 to=0
tidyz=00 tidx=1 from=4 to=4
tidyz=00 tidx=2 from=8 to=8
tidyz=00 tidx=3 from=12 to=12
tidyz=00 tidx=4 from=16 to=16
tidyz=00 tidx=5 from=20 to=20
tidyz=00 tidx=6 from=24 to=24
tidyz=00 tidx=7 from=28 to=28
tidyz=00 tidx=8 from=32 to=32
tidyz=00 tidx=9 from=36 to=36
tidyz=00 tidx=10 from=40 to=40
tidyz=00 tidx=11 from=44 to=44
tidyz=00 tidx=12 from=48 to=48
tidyz=00 tidx=13 from=52 to=52
tidyz=00 tidx=14 from=56 to=56
tidyz=00 tidx=15 from=60 to=60
tidyz=00 tidx=16 from=64 to=64
tidyz=00 tidx=17 from=68 to=68
tidyz=00 tidx=18 from=72 to=72
tidyz=00 tidx=19 from=76 to=76
tidyz=00 tidx=20 from=80 to=80
tidyz=00 tidx=21 from=84 to=84
tidyz=00 tidx=22 from=88 to=88
tidyz=00 tidx=23 from=92 to=92
tidyz=00 tidx=24 from=96 to=96
tidyz=00 tidx=25 from=100 to=100
tidyz=00 tidx=26 from=104 to=104
tidyz=00 tidx=27 from=108 to=108
tidyz=00 tidx=28 from=112 to=112
tidyz=00 tidx=29 from=116 to=116
tidyz=00 tidx=30 from=120 to=120
tidyz=00 tidx=31 from=124 to=124


Collaborator:
So changing has_smem_epilogue to true makes the memory access contiguous, but the perf drops? That is strange.
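One way to read the two dumps above is to count how many 128-byte sectors the first warp's store touches in each case. This is a small standalone sketch; it assumes the dumped indices are element offsets into a half-precision (2-byte) output and hard-codes the patterns from the "to" columns:

#include <cstdio>
#include <set>

int main() {
  std::set<long> strided_sectors, coalesced_sectors;
  for (long tidx = 0; tidx < 32; ++tidx) {
    // has_smem_epilogue = false: groups of 4 threads, each group 4096 elements apart
    const long idx_strided = (tidx / 4) * 4096 + (tidx % 4) * 2;
    // has_smem_epilogue = true: threads cover consecutive 4-element chunks
    const long idx_coalesced = tidx * 4;
    strided_sectors.insert(idx_strided * 2 / 128);    // byte address -> 128 B sector
    coalesced_sectors.insert(idx_coalesced * 2 / 128);
  }
  // Prints 8 vs 2: the strided pattern touches four times as many 128 B
  // sectors per warp store, which is what the shared memory epilogue removes.
  std::printf("sectors touched per warp: strided=%zu coalesced=%zu\n",
              strided_sectors.size(), coalesced_sectors.size());
  return 0;
}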

if (allow_smem_epilogue) {
const auto& roles_map_opt = mma_utils::getTensorsRoles(fusion);
TORCH_INTERNAL_ASSERT(
roles_map_opt.isValid(), "Tensor roles map in mma is not valid.");
// Check if we have enough shared memory for epilogue
params->has_smem_epilogue = mma_utils::hasEnoughSharedMemoryForEpilogue(
params->tile_sizes,
params->double_buffer_options.smem_double_buffer_stage,
getMmaDataTypes(roles_map_opt.getData()));
} else {
params->has_smem_epilogue = false;
}

if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) {
printMsg(params->toString());
}
60 changes: 50 additions & 10 deletions csrc/scheduler/mma_utils.cpp
@@ -6,6 +6,7 @@
*/
// clang-format on

#include <ATen/cuda/CUDAContext.h>
#include <device_lower/utils.h>
#include <expr_evaluator.h>
#include <ir/printer.h>
@@ -14,11 +15,49 @@
#include <scheduler/utils.h>
#include <variant>
#include "mma_type.h"

namespace nvfuser {

namespace mma_utils {

bool hasEnoughSharedMemoryForEpilogue(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage,
const MmaDataTypes& data_types) {
const auto properties = at::cuda::getCurrentDeviceProperties();
const size_t device_smem_limit = properties->sharedMemPerBlockOptin;

auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile;
const auto threads_per_block =
warp_dims.m * warp_dims.n * warp_dims.k * properties->warpSize;
// a thread can use up to 255 registers, blocks per sm is limited by available
// registers
const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255);
const auto blocks_per_sm = threads_per_sm / threads_per_block;
Collaborator:
nit: rename to blocks_per_sm_by_register

// see scheduleContiguousVectorLoad
const int vector_word = 8;
const int round_to_factor = warp_dims.m * warp_dims.n * warp_dims.k *
properties->warpSize * vector_word;
const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k;
const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k;
const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) *
round_to_factor * smem_double_buffer_stage) *
dataTypeSize(data_types[0]);
const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) *
drzejan2 marked this conversation as resolved.
round_to_factor * smem_double_buffer_stage) *
dataTypeSize(data_types[1]);
const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) *
dataTypeSize(data_types[2]);

// use additional shared memory for epilogue if blocks per sm is not changed
const auto blocks_per_sm_without_smem_epilogue =
std::min(device_smem_limit / (smem_a + smem_b), (size_t)blocks_per_sm);
const auto blocks_per_sm_with_smem_epilogue = std::min(
device_smem_limit / (smem_a + smem_b + smem_c), (size_t)blocks_per_sm);
return blocks_per_sm_with_smem_epilogue ==
blocks_per_sm_without_smem_epilogue ||
blocks_per_sm_with_smem_epilogue > 0;
Collaborator (Author):
The "|| blocks_per_sm_with_smem_epilogue > 0" condition is temporary, added to allow testing as many cases as possible without considering performance. It will be deleted before merge.

}
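To make the check concrete, here is a standalone reimplementation of its arithmetic for the configuration discussed in the review thread (cta_tile 64x128x32, warp_tile 32x32x32, two double-buffer stages, half-precision operands and output). The 163 KiB opt-in shared memory limit and the register-limited count of one block per SM are assumed A100 values, not taken from this PR:

#include <algorithm>
#include <cstdio>

int main() {
  const int cta_m = 64, cta_n = 128, cta_k = 32;
  const int warp_m = 32, warp_n = 32, warp_k = 32;
  const int stages = 2, warp_size = 32, vector_word = 8;
  const long dtype_size = 2;  // half for A, B and D in this sketch

  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const int warps = (cta_m / warp_m) * (cta_n / warp_n) * (cta_k / warp_k);  // 8
  const int round_to = warps * warp_size * vector_word;                      // 2048

  const long smem_a = ceil_div(cta_m * cta_k, round_to) * round_to * stages * dtype_size;  // 8 KiB
  const long smem_b = ceil_div(cta_n * cta_k, round_to) * round_to * stages * dtype_size;  // 16 KiB
  const long smem_c = long(cta_m) * cta_n * dtype_size;                                    // 16 KiB

  const long smem_limit = 163 * 1024;  // assumed sharedMemPerBlockOptin (A100)
  const long blocks_by_regs = 1;       // assumed register-limited blocks per SM
  const long without_epilogue = std::min(smem_limit / (smem_a + smem_b), blocks_by_regs);
  const long with_epilogue = std::min(smem_limit / (smem_a + smem_b + smem_c), blocks_by_regs);

  // Both bounds come out as 1 block per SM, so reserving the extra epilogue
  // buffer does not reduce occupancy for this tile configuration.
  std::printf("A=%ld B=%ld C=%ld blocks/SM: %ld vs %ld\n",
              smem_a, smem_b, smem_c, without_epilogue, with_epilogue);
  return 0;
}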

void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
// Assumes
// [M, N, K]
@@ -379,11 +418,6 @@ bool canValidateIsInnerDim(
if (!split->factor()->isConstInt()) {
return false;
}
if (split->factor()->evaluateInt() < inner_dim_size) {
drzejan2 marked this conversation as resolved.
// This might be too restrictive. Would need more
// bookkeeping to relax.
return false;
}
zasdfgbnm marked this conversation as resolved.
leaf = split->in();
} else if (auto merge = dynamic_cast<Merge*>(expr)) {
// Might consider just rejecting merge.
@@ -396,9 +430,6 @@
if (!leaf->extent()->isConstInt()) {
return false;
}
if (leaf->extent()->evaluateInt() != inner_dim_size) {
return false;
}
leaf = merge->inner();
} else {
// No support for swizzled inner dim for now.
@@ -438,7 +469,9 @@ void checkDimSize(
":",
id->extent()->evaluateInt(),
"vs",
expect[axis_index]);
expect[axis_index],
"\n for tv: ",
tv->toString());
}
}

@@ -699,6 +732,13 @@ void validateMmaRootInnerMNK(
//! swizzles to the right axes.
//! This check will be relaxed as we build out the mma usage patterns.
void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) {
auto is_mma_output =
tv->definition() != nullptr && tv->definition()->isA<MmaOp>();
// This function is also used to transform epilogue tensor. It is not a mma
// output and can skip the following checks.
if (!is_mma_output) {
return;
}
zasdfgbnm marked this conversation as resolved.
auto mma = options.mmaOp();
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
10 changes: 10 additions & 0 deletions csrc/scheduler/mma_utils.h
@@ -226,6 +226,10 @@ using ProblemIterDomains = std::array<IterDomain*, 3>;
//! a single tv, for example input for beta scaling in epilogue
using RolesMap = std::map<MatmulRole, std::vector<TensorView*>>;

//! An alias for storing data types of the tensors in the mma op
//! the order is INPUT_A, INPUT_B, OUTPUT_D
using MmaDataTypes = std::array<DataType, 3>;

//! A wrapper for data containers with optional error message stored if
//! initialization of the data fails.
template <typename DataType>
@@ -289,6 +293,12 @@ TORCH_CUDA_CU_API ProblemIterDomainsOpt getProblemIterDomains(Fusion* fusion);
//! be gathered.
TORCH_CUDA_CU_API RolesMapOpt getTensorsRoles(Fusion* fusion);

//! Check if there is enough shared memory for the given tile options
TORCH_CUDA_CU_API bool hasEnoughSharedMemoryForEpilogue(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage,
const MmaDataTypes& data_types);

} // namespace mma_utils

} // namespace nvfuser
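A hedged usage sketch of the new helper, with the tile sizes from the review discussion above. The GemmTile and MatMulTileOptions member names are assumed from mma_type.h rather than shown in this diff:

#include <scheduler/mma_utils.h>

using namespace nvfuser;

bool smemEpilogueFits() {
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(64, 128, 32);        // CTA tile (M, N, K)
  gemm_tile.warp_tile = GemmTile(32, 32, 32);        // warp tile (M, N, K)
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);  // Ampere mma instruction tile

  // True when the heuristic decides the epilogue buffer still fits.
  return mma_utils::hasEnoughSharedMemoryForEpilogue(
      gemm_tile,
      /*smem_double_buffer_stage=*/2,
      mma_utils::MmaDataTypes{DataType::Half, DataType::Half, DataType::Float});
}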