add epilogue to store MMA results in shared memory before write to #387

Merged Jul 17, 2023 (36 commits)

Changes from 16 commits

Commits
379eab5
add epilogue to store MMA results in shared memory before write to
liqiangxl May 22, 2023
b500d36
revise test
liqiangxl Jun 8, 2023
acf1167
format
liqiangxl Jun 8, 2023
be885b0
swizzleSharedMemory
liqiangxl Jun 20, 2023
7138cbe
format
liqiangxl Jun 20, 2023
fad4ad8
fix failed test cases
liqiangxl Jun 21, 2023
84e3e98
propagate to epilogue tensors
liqiangxl Jun 28, 2023
a94e5df
check num_shared_mem_tensors
liqiangxl Jun 28, 2023
9f4bcc4
format
liqiangxl Jun 28, 2023
1760c15
disable_smem_epilogue
liqiangxl Jun 28, 2023
f0ff6f9
extend MatmulSASSTest
liqiangxl Jun 28, 2023
da5dc3a
schedule output tensor
liqiangxl Jun 29, 2023
537b855
wip
liqiangxl Jun 29, 2023
ab86a1f
use propagate
liqiangxl Jul 1, 2023
f2a75cd
fix failed case
liqiangxl Jul 3, 2023
7a4d5b5
fix ci fails by increasing tolerance:x
liqiangxl Jul 4, 2023
86b8911
merge main
liqiangxl Jul 7, 2023
5586b3a
fix failed cases
liqiangxl Jul 7, 2023
6a8f139
trivial fix
liqiangxl Jul 7, 2023
d6212cb
format
liqiangxl Jul 9, 2023
95ea553
revise hasEnoughSharedMemoryForEpilogue
liqiangxl Jul 9, 2023
925e04d
merge main
liqiangxl Jul 9, 2023
1f30a36
wip
liqiangxl Jul 10, 2023
80d7588
cacheAfter mma_result
liqiangxl Jul 10, 2023
d3019f0
add epilogue cast and relu tests
liqiangxl Jul 10, 2023
212258c
trivial fix
liqiangxl Jul 10, 2023
a2045cd
mma data types
liqiangxl Jul 12, 2023
32f43d8
merge main
liqiangxl Jul 12, 2023
67ecdb0
revise smem swizzle
liqiangxl Jul 13, 2023
864a918
test with revised swizzle
liqiangxl Jul 13, 2023
79c452c
merge main
liqiangxl Jul 14, 2023
26d970d
save file
liqiangxl Jul 14, 2023
189cef4
revise based on review comments
liqiangxl Jul 17, 2023
889ce23
rename
liqiangxl Jul 17, 2023
94aae9b
change default to false
liqiangxl Jul 17, 2023
27870d0
Merge branch 'main' into llu/matmul_epilogue
liqiangxl Jul 17, 2023

341 changes: 252 additions & 89 deletions csrc/scheduler/matmul.cpp

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions csrc/scheduler/matmul_heuristic.h
@@ -90,6 +90,10 @@ class MatmulParams : public HeuristicParams {
//! C3 C4 D3 D4
int grid_swizzle_factor = 1;

//! Unswizzle MMA results in shared memory to get
//! coalesced write to global memory
bool has_smem_epilogue = true;

std::string toString() const override {
std::stringstream ss;
ss << "\n===== Matmul Parameters ========\n"
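For intuition about what has_smem_epilogue enables, here is a standalone CUDA sketch of the underlying staging pattern. This is purely illustrative and not nvfuser-generated code: the 16x16 tile, the kernel name, and the thread-to-element mapping are made up. The point is only that results sitting in registers in an MMA-friendly (non-contiguous) layout are first written to shared memory and then read back row-major, so the block's store to global memory is coalesced.

#include <cuda_runtime.h>

// One 16x16 float output tile per block, 256 threads, one element per thread.
__global__ void epilogue_via_smem(float* __restrict__ out, int ld) {
  __shared__ float tile[16][17];  // +1 column of padding to avoid bank conflicts

  // Stand-in for the accumulator fragment an MMA leaves in registers; the
  // thread-to-element mapping is typically not row-major/coalesced.
  float acc = static_cast<float>(threadIdx.x);

  // Stage the fragment to shared memory in its "MMA layout".
  int frag_row = threadIdx.x % 16;
  int frag_col = threadIdx.x / 16;
  tile[frag_row][frag_col] = acc;
  __syncthreads();

  // Read back row-major so consecutive threads write consecutive global
  // addresses: a coalesced store.
  int row = threadIdx.x / 16;
  int col = threadIdx.x % 16;
  out[(blockIdx.y * 16 + row) * ld + blockIdx.x * 16 + col] = tile[row][col];
}

int main() {
  const int n = 64;
  float* d_out = nullptr;
  cudaMalloc(&d_out, n * n * sizeof(float));
  dim3 grid(n / 16, n / 16);
  epilogue_via_smem<<<grid, 256>>>(d_out, n);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}

The padding column above is the simplest way to dodge bank conflicts; the scheduler in this PR instead uses a swizzled shared-memory layout (see the swizzleSharedMemory commits) for the same purpose.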
11 changes: 11 additions & 0 deletions csrc/scheduler/matmul_utils.cpp
@@ -376,6 +376,17 @@ std::shared_ptr<MatmulParams> getMatmulHeuristics(
// Disable magic zero for matmul kernels
params->cparams.enable_magic_zero = false;

// Disable shared memory epilogue before shared memory reuse is implemented.
// Otherwise, there will be performance regression due to reduced occupancy
// caused by extra shared memory usage.
constexpr bool disable_smem_epilogue = false;
if (!disable_smem_epilogue) {
// Check if we have enough shared memory for epilogue
params->has_smem_epilogue = mma_utils::hasEnoughSharedMemoryForEpilogue(
params->tile_sizes,
params->double_buffer_options.smem_double_buffer_stage);
}

if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) {
printMsg(params->toString());
}
48 changes: 38 additions & 10 deletions csrc/scheduler/mma_utils.cpp
@@ -6,6 +6,7 @@
*/
// clang-format on

#include <ATen/cuda/CUDAContext.h>
#include <device_lower/utils.h>
#include <expr_evaluator.h>
#include <ir/printer.h>
@@ -14,11 +15,37 @@
#include <scheduler/utils.h>
#include <variant>
#include "mma_type.h"

namespace nvfuser {

namespace mma_utils {

bool hasEnoughSharedMemoryForEpilogue(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage) {
auto properties = at::cuda::getDeviceProperties(
c10::Device(c10::DeviceType::CUDA, 0).index());
const size_t device_smem_limit = properties->sharedMemPerBlockOptin;

// see scheduleContiguousVectorLoad
const int vector_word = 8;
auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile;
const int round_to_factor =
warp_dims.m * warp_dims.n * warp_dims.k * 32 * vector_word;
const int mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k;
const int nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k;
const size_t smem_a = (size_t)(ceilDiv(mk, round_to_factor) *
round_to_factor * smem_double_buffer_stage) *
dataTypeSize(DataType::Half);
const size_t smem_b = (size_t)(ceilDiv(nk, round_to_factor) *
round_to_factor * smem_double_buffer_stage) *
dataTypeSize(DataType::Half);
const size_t smem_c = (size_t)(gemm_tile.cta_tile.m * gemm_tile.cta_tile.n) *
dataTypeSize(DataType::Float);
const size_t smem_size = smem_a + smem_b + smem_c;

return smem_size <= device_smem_limit;

Collaborator:

I think you are checking whether smem_size <= device_smem_limit and enabling the smem unswizzle when this condition is satisfied. I don't think this is the correct behavior. Each kernel has its designed occupancy; if the designed occupancy is N, then we should have N*smem_size <= device_smem_limit. cc: @drzejan2 Any opinion on this?

Collaborator:

You are right, @zasdfgbnm.

The first thought that comes to my mind is to calculate the current occupancy where only smem_a and smem_b are in use, and then check that the new occupancy, with smem_a, smem_b, and smem_c, does not change. Something along these lines:

const auto curr_smem_occupancy = device_smem_limit / (smem_a + smem_b);
const auto new_smem_occupancy = device_smem_limit / (smem_a + smem_b + smem_c);

return new_smem_occupancy >= curr_smem_occupancy;

Collaborator Author:

Good point. But without memory reuse, new_smem_occupancy < curr_smem_occupancy is almost always true. For example, I checked NVFuserTest.FusionAmpereMatmulTileCheck4warp_CUDA and the results are:

curr_smem_occupancy= 20 new_smem_occupancy= 13
curr_smem_occupancy= 10 new_smem_occupancy= 8
curr_smem_occupancy= 10 new_smem_occupancy= 8
curr_smem_occupancy= 10 new_smem_occupancy= 5
curr_smem_occupancy= 6 new_smem_occupancy= 4
curr_smem_occupancy= 5 new_smem_occupancy= 3
curr_smem_occupancy= 6 new_smem_occupancy= 2
curr_smem_occupancy= 4 new_smem_occupancy= 2
curr_smem_occupancy= 3 new_smem_occupancy= 1
curr_smem_occupancy= 5 new_smem_occupancy= 1
curr_smem_occupancy= 3 new_smem_occupancy= 1
curr_smem_occupancy= 2 new_smem_occupancy= 1
curr_smem_occupancy= 4 new_smem_occupancy= 1
curr_smem_occupancy= 2 new_smem_occupancy= 0

This function was designed to ensure we can launch the kernel with the shared memory epilogue, without considering the occupancy. By doing this we can increase the number of cases that can be tested with the smem epilogue. For performance purposes we definitely need N*smem_size <= device_smem_limit; in this revision N is derived from the register limitation. For the case NVFuserTest.FusionAmpereMatmulTileCheck4warp_CUDA, N is 1 or 2.

blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 1 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 1 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 1 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 2
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 1
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 2 blocks_per_sm_with_smem_epilogue= 0
blocks_per_sm_without_smem_epilogue= 1 blocks_per_sm_with_smem_epilogue= 0

Collaborator:

Thank you @liqiangxl for running the verification of smem-based and register-based occupancy. The updated implementation looks good.

}
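To make the criterion from the review thread above concrete, here is a hedged sketch, not part of this PR, of an occupancy-aware variant of the check: with a designed occupancy of N resident blocks, require N*smem_size <= device_smem_limit. The function and parameter names are hypothetical; in the author's follow-up, N comes from a register-based occupancy estimate.

#include <cstddef>

// Hedged sketch of the occupancy-aware criterion discussed above; all names
// are illustrative, not nvfuser API.
bool smemEpilogueFitsAtOccupancy(
    std::size_t smem_a,             // operand A staging across all stages, bytes
    std::size_t smem_b,             // operand B staging across all stages, bytes
    std::size_t smem_c,             // epilogue (MMA result) staging, bytes
    std::size_t device_smem_limit,  // e.g. sharedMemPerBlockOptin
    int target_blocks_per_sm) {     // N, the designed occupancy
  const std::size_t per_block = smem_a + smem_b + smem_c;
  return static_cast<std::size_t>(target_blocks_per_sm) * per_block <=
      device_smem_limit;
}

With target_blocks_per_sm = 1 this reduces to the launchability check implemented by hasEnoughSharedMemoryForEpilogue above.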

void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
// Assumes
// [M, N, K]
@@ -379,11 +406,6 @@ bool canValidateIsInnerDim(
if (!split->factor()->isConstInt()) {
return false;
}
if (split->factor()->evaluateInt() < inner_dim_size) {
// This might be too restrictive. Would need more
// bookkeeping to relax.
return false;
}
leaf = split->in();
} else if (auto merge = dynamic_cast<Merge*>(expr)) {
// Might consider just rejecting merge.
@@ -396,9 +418,6 @@
if (!leaf->extent()->isConstInt()) {
return false;
}
if (leaf->extent()->evaluateInt() != inner_dim_size) {
return false;
}
leaf = merge->inner();
} else {
// No support for swizzled inner dim for now.
@@ -438,7 +457,9 @@ void checkDimSize(
":",
id->extent()->evaluateInt(),
"vs",
expect[axis_index]);
expect[axis_index],
"\n for tv: ",
tv->toString());
}
}

@@ -699,6 +720,13 @@ void validateMmaRootInnerMNK(
//! swizzles to the right axes.
//! This check will be relaxed as we build out the mma usage patterns.
void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) {
auto is_mma_output =
tv->definition() != nullptr && tv->definition()->isA<MmaOp>();
// This function is also used to transform epilogue tensors. An epilogue
// tensor is not an mma output, so the following checks can be skipped.
if (!is_mma_output) {
return;
}
auto mma = options.mmaOp();
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
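As a worked example of the estimate made by hasEnoughSharedMemoryForEpilogue above, this standalone program redoes the arithmetic for one hypothetical configuration: cta_tile 128x128x32, warp_tile 64x64x32, two double-buffer stages. The tile sizes are chosen for illustration and are not taken from the PR.

#include <cstddef>
#include <cstdio>

constexpr int ceilDiv(int a, int b) {
  return (a + b - 1) / b;
}

int main() {
  constexpr int cta_m = 128, cta_n = 128, cta_k = 32;   // hypothetical CTA tile
  constexpr int warp_m = 64, warp_n = 64, warp_k = 32;  // hypothetical warp tile
  constexpr int stages = 2;                             // smem double-buffer stages
  constexpr int vector_word = 8;  // see scheduleContiguousVectorLoad

  // warp_dims = cta_tile / warp_tile = {2, 2, 1}
  constexpr int round_to_factor =
      (cta_m / warp_m) * (cta_n / warp_n) * (cta_k / warp_k) * 32 * vector_word;
  // = 2 * 2 * 1 * 32 * 8 = 1024 elements

  constexpr std::size_t smem_a =
      static_cast<std::size_t>(ceilDiv(cta_m * cta_k, round_to_factor)) *
      round_to_factor * stages * 2 /* sizeof(half) */;  // 16 KiB
  constexpr std::size_t smem_b =
      static_cast<std::size_t>(ceilDiv(cta_n * cta_k, round_to_factor)) *
      round_to_factor * stages * 2 /* sizeof(half) */;  // 16 KiB
  constexpr std::size_t smem_c =
      static_cast<std::size_t>(cta_m) * cta_n * 4 /* sizeof(float) */;  // 64 KiB

  // Total: 96 KiB, so this configuration enables the smem epilogue only on
  // devices whose sharedMemPerBlockOptin is at least 98304 bytes.
  std::printf("smem_a + smem_b + smem_c = %zu bytes\n", smem_a + smem_b + smem_c);
  return 0;
}

For this configuration the epilogue buffer (64 KiB of float accumulators) dominates the 32 KiB of operand staging, which is why the heuristic gates it on the device's opt-in shared-memory limit.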
5 changes: 5 additions & 0 deletions csrc/scheduler/mma_utils.h
@@ -17,6 +17,11 @@ namespace nvfuser {

namespace mma_utils {

//! Check if there is enough shared memory for the given tile options
TORCH_CUDA_CU_API bool hasEnoughSharedMemoryForEpilogue(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage);

//! Utilities in this namespace facilitate scheduling matmul kernels with
//! hierarchical tiling specified in MatMulTileOptions.

16 changes: 10 additions & 6 deletions csrc/transform_replay.cpp
@@ -333,7 +333,7 @@ std::pair<TensorDomain*, size_t> TransformReplay::replayPasC(
consumer,
(int)consumer_pos,
root_map,
false,
opt.skip_target_swizzle,
!opt.replay_swizzle,
!opt.replay_resize);

@@ -609,7 +609,7 @@ std::pair<TensorDomain*, size_t> TransformReplay::replayCasP(
producer,
(int)producer_pos,
root_map,
false,
opt.skip_target_swizzle,
!opt.replay_swizzle,
!opt.replay_resize);

@@ -1085,7 +1085,8 @@ void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) {
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
auto replay = TransformReplay::replayPasC(
to, from, pos, TransformReplayOptions().skipTargetSwizzle());
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay.first),
"Tried to set the domain of ",
@@ -1116,7 +1117,8 @@ void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) {
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
auto replay = TransformReplay::replayCasP(
to, from, pos, TransformReplayOptions().skipTargetSwizzle());
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay.first),
"Tried to set the domain of ",
@@ -1187,7 +1189,8 @@ void MostInlinedTransformPropagator::propagateC2P(
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayPasC(to, from, pos);
auto replay = TransformReplay::replayPasC(
to, from, pos, TransformReplayOptions().skipTargetSwizzle());
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay.first),
"Tried to set the domain of ",
@@ -1218,7 +1221,8 @@ void MostInlinedTransformPropagator::propagateP2C(
std::cout << " to: " << to << std::endl;
}
if (new_pos < 0) {
auto replay = TransformReplay::replayCasP(to, from, pos);
auto replay = TransformReplay::replayCasP(
to, from, pos, TransformReplayOptions().skipTargetSwizzle());
TORCH_INTERNAL_ASSERT(
validateDomain(to, replay.first),
"Tried to set the domain of ",
31 changes: 31 additions & 0 deletions csrc/transform_replay.h
@@ -130,10 +130,41 @@ class TensorView;
class RootDomainMap;

struct TransformReplayOptions {
// In theory, it makes more sense to have skip_target_swizzle = true by
// default because this is how we index into the producer and how we propagate
// transformations. However, we are in a very funny situation that:
// BestEffortReplay for swizzle is broken. For example, if we have a
// producer <=> consumer pair like:
//     I1              I0
//    /  \            /  \
//  I1o   I1i       I0o   I0i
//   |     |         |     |
// swizzle I1i     swizzle I0i     <=>    I3    I2
//   |     |         |     |
//  I1o'  I1i       I0o'  I0i
//    \   /           \   /
//     I1'             I0'
// where I1o', I0o' = swizzle(I1o, I0o), we never really skipped swizzle to
// map I1' with I3 and I0' with I2. But even with this error, our swizzle
// indexing worked due to luck. So effectively we were doing
// skip_target_swizzle = false. But today, we can not make this `true` for
// vectorization validation and indexing, because of another bug in
// BestEffortReplay: swizzle skip should happen in an all-or-nothing fashion.
// We can not just skip X but not skip Y, but we are not implementing this
// skip like that. If we make it `true`, this will trigger some error in some
// schedule. So here, in order to avoid exposing one bug, we are more
// explicitly using a wrong behavior that we have been using because this
// wrong behavior has a better luck.
bool skip_target_swizzle = false;
bool replay_swizzle = false;
bool replay_resize = false;
bool replay_allocation = false;

TransformReplayOptions& skipTargetSwizzle(bool value = true) {
skip_target_swizzle = value;
return *this;
}

TransformReplayOptions& replaySwizzle(bool value = true) {
replay_swizzle = value;
return *this;
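For readers unfamiliar with the option-builder idiom that skipTargetSwizzle() follows, here is a minimal, self-contained reproduction outside nvfuser (the struct and field names are shortened stand-ins, not the real type): each flag defaults to false, and each setter defaults its argument to true and returns *this, so call sites can chain options inline as in the propagator changes above.

#include <iostream>

struct Options {
  bool skip_target_swizzle = false;
  bool replay_swizzle = false;

  Options& skipTargetSwizzle(bool value = true) {
    skip_target_swizzle = value;
    return *this;
  }
  Options& replaySwizzle(bool value = true) {
    replay_swizzle = value;
    return *this;
  }
};

int main() {
  // Mirrors the shape of the call sites added in transform_replay.cpp above:
  //   replayPasC(to, from, pos, TransformReplayOptions().skipTargetSwizzle())
  Options opt = Options().skipTargetSwizzle();
  std::cout << "skip_target_swizzle=" << opt.skip_target_swizzle
            << " replay_swizzle=" << opt.replay_swizzle << "\n";
  return 0;
}

As the long comment above explains, skip_target_swizzle stays false by default, and only the transform propagators opt in explicitly.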