
Commit

rebase fix (as promised :) ). see FusionAmpereMatmulLargeLoad for a working example.
shmsong committed Sep 24, 2022
1 parent 1e5e745 commit 2e37ff1
Showing 3 changed files with 21 additions and 114 deletions.
25 changes: 20 additions & 5 deletions torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -404,6 +404,13 @@ void scheduleMatmul(
scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(
cc, gemm_tile, true);

+ // Move the Mw up for epilog:
+ if (params.has_epilog) {
+   cc->reorder({{4, 5}, {5, 6}, {6, 4}});
+   //  0   1   2   3   4   5    6    7   8   9   10
+   // [Mo  No  Ko  Kw  Mw  Mwo  Nwo  Nw (Mi  Ni  Ki)]
+ }
+
// Propagate warp tile to main loop and epilog/output tvs
scheduler_utils::BoundedDirectionalTransformPropagator::bothWays(
cc, -1, {acw_smem, bcw_smem}, {c});
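
Aside (not part of the diff): the reorder call added above takes a map from old axis position to new axis position. The standalone sketch below is a minimal model of that permutation, assuming the map is old-to-new and that unmapped axes keep their relative order; `applyReorder` and the `a0..a7` labels are hypothetical, and only the map `{{4, 5}, {5, 6}, {6, 4}}` comes from the change itself.

```cpp
// Illustrative only: a standalone model of an {old position, new position}
// axis permutation. Not nvfuser code; only the example map is from the diff.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

std::vector<std::string> applyReorder(
    const std::vector<std::string>& axes,
    const std::unordered_map<int, int>& old2new) {
  std::vector<std::string> result(axes.size());
  std::vector<bool> slot_taken(axes.size(), false);
  std::vector<bool> axis_placed(axes.size(), false);
  // Place the explicitly mapped axes first.
  for (const auto& kv : old2new) {
    result[kv.second] = axes[kv.first];
    slot_taken[kv.second] = true;
    axis_placed[kv.first] = true;
  }
  // Unmapped axes keep their relative order and fill the remaining slots.
  size_t slot = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    if (axis_placed[i]) continue;
    while (slot_taken[slot]) ++slot;
    result[slot++] = axes[i];
  }
  return result;
}

int main() {
  std::vector<std::string> axes = {"a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7"};
  for (const auto& name : applyReorder(axes, {{4, 5}, {5, 6}, {6, 4}}))
    std::cout << name << ' ';
  std::cout << '\n'; // prints: a0 a1 a2 a3 a6 a4 a5 a7
}
```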
@@ -509,8 +516,15 @@ void scheduleMatmul(
// [Mo No Ko Kw Mw Nw Mwo Nwo(Mi Ni Ki)]
cc->axis(0)->parallelize(ParallelType::BIDx);
cc->axis(1)->parallelize(ParallelType::BIDy);
- cc->axis(6)->parallelize(ParallelType::TIDz);
- cc->axis(7)->parallelize(ParallelType::TIDy);
+
+ // Maybe just keep one of these two options.
+ if (params.has_epilog) {
+   cc->axis(5)->parallelize(ParallelType::TIDz);
+   cc->axis(6)->parallelize(ParallelType::TIDy);
+ } else {
+   cc->axis(4)->parallelize(ParallelType::TIDz);
+   cc->axis(5)->parallelize(ParallelType::TIDy);
+ }

scheduler_utils::parallelizeAllLike(
cc,
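
Aside (not part of the diff): the TIDz/TIDy bindings above put the two warp-outer loops on block dimensions z and y. As a rough sanity check, using the 128x128x64 CTA tile and 64x64x64 warp tile from one of the test configurations in this commit, that gives a 2x2 arrangement of warps per CTA; the sketch below just spells out that arithmetic and assumes TIDx carries the 32 lanes of each warp.

```cpp
// Back-of-the-envelope check of the warp parallelization above, using tile
// sizes that appear in this commit's tests. Assumes TIDx holds the 32 lanes
// of a warp while the warp-outer M/N axes are bound to TIDz/TIDy.
#include <iostream>

int main() {
  const int cta_m = 128, cta_n = 128; // CTA tile (M, N)
  const int warp_m = 64, warp_n = 64; // warp tile (M, N)

  const int warps_m = cta_m / warp_m; // bound to TIDz -> 2
  const int warps_n = cta_n / warp_n; // bound to TIDy -> 2
  const int lanes = 32;               // TIDx

  std::cout << "blockDim = (" << lanes << ", " << warps_n << ", " << warps_m
            << "), " << lanes * warps_n * warps_m << " threads per CTA\n";
  // Expected: blockDim = (32, 2, 2), 128 threads per CTA
}
```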
@@ -560,15 +574,15 @@ void scheduleMatmul(
if (params.has_epilog) {
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
c_smem,
- 3,
+ 4,
{c},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());

- c_smem->computeAt(c, 3);
+ c_smem->computeAt(c, 4);
c->reorder({{-1, -2}, {-2, -1}});
- // 16 x 128, with half of the warps:
+ // 16 x 128, with 2 warps:

// Output vectorize by 4:
c->split(-2, 2);
@@ -577,6 +591,7 @@ void scheduleMatmul(
// [8, 2, 32, 4]
c->axis(-3)->parallelize(ParallelType::TIDy);
c->axis(-2)->parallelize(ParallelType::TIDx);

c->axis(-1)->parallelize(ParallelType::Vectorize);
c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
c_smem->doubleBuffer();
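
Aside (not part of the diff): part of the epilog split sequence is collapsed above, but the inline comments state that the output copy works on a 16 x 128 tile with 2 warps and ends up shaped as [8, 2, 32, 4] with TIDy, TIDx, and Vectorize on the last three axes. The short check below only verifies that those numbers are mutually consistent; the per-thread store count is inferred here, not taken from the source.

```cpp
// Consistency check for the epilog output tiling described by the comments
// above ("16 x 128, with 2 warps" and the [8, 2, 32, 4] shape).
#include <iostream>

int main() {
  const int tile_elems = 16 * 128;                    // epilog tile
  const int serial = 8, tidy = 2, tidx = 32, vec = 4; // [8, 2, 32, 4]

  std::cout << "tile elements: " << tile_elems << '\n'           // 2048
            << "covered: " << serial * tidy * tidx * vec << '\n' // 2048
            << "threads: " << tidy * tidx << " (2 warps)\n"      // 64
            << "vectorized stores per thread: " << serial << '\n';
}
```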
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/matmul.h
@@ -53,7 +53,7 @@ class MatmulParam {
bool peel_main_loop = true;

//! Enables an epilog schedule
- bool has_epilog = false;
+ bool has_epilog = true;
};

//! Prototype auto scheduling function.
108 changes: 0 additions & 108 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -2826,114 +2826,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
}
}

TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity1_CUDA) {
// Keep multiples of 8 to keep vectorizable.
int M = 2048, N = 3456, K = 1024;
for (auto layout : kAllSupportedLayout) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeContigTensor(2, DataType::Half);
auto tv1 = makeContigTensor(2, DataType::Half);

fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = matmul(tv0, tv1, layout);

fusion.addOutput(tv2);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 128, 64);
gemm_tile.warp_tile = GemmTile(64, 64, 64);
gemm_tile.instruction_tile = GemmTile(16, 16, 16);

auto mma_builder =
MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
.layout(layout);

MatmulParam params(mma_builder);
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.smem_double_buffer_stage = 3;
scheduleMatmul(tv2, tv0, tv1, params);

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

CompileOptions co;
co.index_mode = KernelIndexMode::INT32;

FusionExecutor fe;
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8,
0,
fe.compileFusion(
&fusion, {inputs.first, inputs.second}, LaunchParams(), co));

// return;
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001));
}
}

TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity2_CUDA) {
// Keep multiples of 8 to keep vectorizable.
int M = 2048, N = 3456, K = 1024;
for (auto layout : kAllSupportedLayout) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeContigTensor(2, DataType::Half);
auto tv1 = makeContigTensor(2, DataType::Half);

fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = matmul(tv0, tv1, layout);

fusion.addOutput(tv2);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(256, 128, 64);
gemm_tile.warp_tile = GemmTile(64, 64, 64);
gemm_tile.instruction_tile = GemmTile(16, 16, 16);

auto mma_builder =
MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
.layout(layout);

MatmulParam params(mma_builder);
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.smem_double_buffer_stage = 2;
scheduleMatmul(tv2, tv0, tv1, params);

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

CompileOptions co;
co.index_mode = KernelIndexMode::INT32;

FusionExecutor fe;
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8,
0,
fe.compileFusion(
&fusion, {inputs.first, inputs.second}, LaunchParams(), co));

// return;
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001));
}
}

// Tile layout check for symmetric 4-warp recipes
TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) {
REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
