Skip to content

Commit

Permalink
add epilogue to store MMA results in shared memory before write to
global memory to achieve coalesced write.
Browse files Browse the repository at this point in the history
TODO: remove bank conflict, reuse shared memory
  • Loading branch information
liqiangxl committed May 23, 2023
1 parent 842d543 commit a8639d4
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 14 deletions.
79 changes: 65 additions & 14 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
// Setup accumulator register.
auto cc = c->cacheBefore();

// Setup output smem buffer
// cc -> c_smem -> c
TensorView* c_smem = nullptr;
if (params.has_epilogue) {
c_smem = c->cacheBefore();
}

// Get the input to the mma op.
auto mma = cc->definition()->as<MmaOp>();
auto ab = mma->inA()->as<TensorView>();
Expand Down Expand Up @@ -652,6 +659,17 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {

// Schedule warp tile
mma_utils::scheduleWarpTileWithReduction(cc, gemm_tile);
// 0 1 2 3 4 5 6 7 8 9 10
// [Mo No Ko Kw Mwo Nwo Mwi Nwi Mi, Ni, Ki]
// e.g. T6_l[ iS25{( ceilDiv(i0, 128) )}, iS27{( ceilDiv(i4, 128) )}, rS29{(
// ceilDiv(i2, 32) )}, rS83{( ceilDiv(32, 16) )}, iS75{( ceilDiv(128, 64) )},
// iS77{( ceilDiv(128, 64) )}, iS79{( ceilDiv(64, 16) )}, iS81{( ceilDiv(64,
// 8) )}, iS80{16}, iS82{8}, rS84{16} ] Move the Mw up for epilog:
if (params.has_epilogue) {
cc->reorder({{4, 5}, {5, 6}, {6, 4}});
// 0 1 2 3 4 5 6 7 8 9 10
// [Mo No Ko Kw Mw Mwo Nwo Nw (Mi Ni Ki)]
}

// Propagate warp tile to main loop and epilog/output tvs
scheduler_utils::BoundedDirectionalTransformPropagator::bothWays(
Expand All @@ -662,7 +680,9 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
// ------------------------------------------------------------------
scheduleProlog(acw_smem, params);
scheduleProlog(bcw_smem, params);

if (params.has_epilogue) {
c_smem->setMemoryType(MemoryType::Shared);
}
// Add mma swizzle:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
Expand Down Expand Up @@ -718,17 +738,58 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
false, "Invalid TileRasterizationOrder passed to Matmul scheduler");
}

cc->axis(4)->parallelize(ParallelType::TIDz);
cc->axis(5)->parallelize(ParallelType::TIDy);
// iS75{( ceilDiv(128, 64) )}, iS77{( ceilDiv(128, 64) )}
if (params.has_epilogue) {
cc->axis(5)->parallelize(ParallelType::TIDz);
cc->axis(6)->parallelize(ParallelType::TIDy);
} else {
cc->axis(4)->parallelize(ParallelType::TIDz);
cc->axis(5)->parallelize(ParallelType::TIDy);
}

scheduler_utils::parallelizeAllLike(
cc,
-1,
{acr, bcr, ab, bb, a, b},
{ParallelType::TIDy, ParallelType::TIDz});

auto output_buffer = params.has_epilogue ? c_smem : c;
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
cc,
-1,
{output_buffer},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

// Epilog schedule (To be built out):
if (params.has_epilogue) {
// T7_s[ iblockIdx.x67{( ceilDiv(i0, 128) )}, iblockIdx.y69{( ceilDiv(i4,
// 128) )}, iS119{( ceilDiv(64, 16) )}, ithreadIdx.z117{( ceilDiv(128, 64)
// )}, ithreadIdx.y121{( ceilDiv(128, 64) )}, iS123{( ceilDiv(64, 8) )},
// iS194{( ceilDiv(16, 8) )}, ithreadIdx.x198{( 8 * ( ceilDiv(8, 2) ) )},
// iS197{2} ]
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
c_smem,
4,
{c},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
c->reorder({{-1, -2}, {-2, -1}});
c->split(-2, 2);
c->split(-1, 4);
// [8, 2, 32, 4]
c->axis(-3)->parallelize(ParallelType::TIDy);
c->axis(-2)->parallelize(ParallelType::TIDx);
c->axis(-1)->parallelize(ParallelType::Vectorize);
c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
} else {
// Always vector
c->axis(-1)->parallelize(ParallelType::Vectorize);
}
// auto inline for all tensors except register tensors and output tensor
inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb, c}));
inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb, c_smem, c}));

// if auto inline, will inline to position-7, leads to performance regression
inlineSelectedAt({acr, bcr, ab, bb}, cc, 6);
Expand All @@ -755,16 +816,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
bcr->doubleBuffer();
}

scheduler_utils::BoundedDirectionalTransformPropagator::forward(
cc,
-1,
{c},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

c->axis(-1)->parallelize(ParallelType::Vectorize);

if (params.double_buffer_options.double_buffer_smem_read &&
params.double_buffer_options.double_buffer_smem_write) {
scheduler_utils::rotateLoop(cc, 2, {acr, bcr});
Expand Down
4 changes: 4 additions & 0 deletions csrc/scheduler/matmul_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class MatmulParams : public HeuristicParams {
//! C3 C4 D3 D4
int grid_swizzle_factor = 1;

//! swizzle MMA results in shared memory to achieve a
//! coalesced write to global memory
bool has_epilogue = false;

std::string toString() const override {
std::stringstream ss;
ss << "\n===== Matmul Parameters ========\n"
Expand Down
63 changes: 63 additions & 0 deletions test/test_gpu_tensorcore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3460,6 +3460,69 @@ TEST_F(NVFuserTest, FusionMatmulSegmenterBasicMatmulRelaxedCheck_CUDA) {
}
}

// Exercises the matmul epilogue path (params.has_epilogue = true): MMA
// results are staged in a shared-memory buffer before the write to global
// memory so the final store is coalesced. Runs once per supported layout.
TEST_F(NVFuserTest, FusionAmpereMatmulEpilogue_CUDA) {
  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;

  for (auto layout : kAllSupportedMatmulLayout) {
    Fusion fusion;
    FusionGuard fg(&fusion);
    // Half-precision operands, matching the Ampere_16_8_16 MMA macro below.
    auto tv0 = makeContigTensor(2, DataType::Half);
    auto tv1 = makeContigTensor(2, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    auto tv2 = matmul(tv0, tv1, layout, true);

    fusion.addOutput(tv2);

    // CTA / warp / instruction tiling used by the scheduler.
    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);
    gemm_tile.instruction_tile = GemmTile(16, 8, 16);

    MatmulParams params;
    params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16;
    params.tile_sizes = gemm_tile;
    // Enable the shared-memory epilogue this commit adds.
    params.has_epilogue = true;
    params.async_gmem_load_operands = true;
    // intentionally set to false to make the generated code simple
    params.double_buffer_options.double_buffer_smem_write = false;
    params.double_buffer_options.double_buffer_smem_read = false;
    params.double_buffer_options.smem_double_buffer_stage = 1;
    scheduleMatmul(&fusion, params);

    // Fixed seed so the numeric comparison below is reproducible.
    at::manual_seed(0);
    auto inputs = matmulAtInput(M, N, K, layout);

    FusionExecutor fe;
    // Compile is gated on compute capability >= 8.0 (Ampere).
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        8,
        0,
        fe.compileFusion(
            &fusion,
            {inputs.first, inputs.second},
            LaunchParams(),
            matmul_cparams));
    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
    // Reference result computed in float by ATen.
    auto tref = atMatmul(
        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));

    // check bank conflicts
    // NOTE(review): disabled — per the commit note the epilogue smem buffer
    // still has bank conflicts; re-enable once they are removed.
    // const std::unordered_map<const Expr*, std::pair<int, int>>&
    //     bank_conflict = getBankConflictInfo(fe.kernel());
    // if (!bank_conflict.empty()) {
    //   for (auto it = bank_conflict.begin(); it != bank_conflict.end();
    //        it++) {
    //     std::cout << "Bank conflict expression: " << it->first->toString()
    //               << " read conflict= " << it->second.first
    //               << ", write conflict= " << it->second.second << std::endl;
    //   }
    //   ASSERT_TRUE(bank_conflict.empty());
    // }
  }
}

#undef NVFUSER_TEST_CUDA_ARCH_GUARD

} // namespace nvfuser

0 comments on commit a8639d4

Please sign in to comment.