

optionally enable epilog schedule (for now)
shmsong committed Sep 11, 2022
1 parent 77f831b commit 97c6ee9
Showing 2 changed files with 44 additions and 27 deletions.
68 changes: 41 additions & 27 deletions torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -311,8 +311,12 @@ void scheduleMatmul(
   // Setup accumulator register.
   auto cc = c->cacheBefore();

-  // Setup output smem buffer
-  auto c_smem = c->cacheBefore();
+  TensorView* c_smem = nullptr;
+
+  if (params.has_epilog) {
+    // Setup output smem buffer
+    c_smem = c->cacheBefore();
+  }

   // Get the input to the mma op.
   auto mma = dynamic_cast<MmaOp*>(cc->definition());
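For orientation, the two cacheBefore() calls above build the following producer chain for the output (a schematic added here for readability, not part of the commit):

    has_epilog == false:  mma -> cc (accumulator registers) -> c (global memory)
    has_epilog == true:   mma -> cc (accumulator registers) -> c_smem (shared memory) -> c (global memory)

Each cacheBefore() call on c inserts a new staging TensorView directly before c, so the second, flag-guarded call adds the shared-memory stage between the accumulator and the global-memory output.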
@@ -483,7 +487,10 @@ void scheduleMatmul(
   // Set memory type:
   acw_smem->setMemoryType(MemoryType::Shared);
   bcw_smem->setMemoryType(MemoryType::Shared);
-  c_smem->setMemoryType(MemoryType::Shared);
+
+  if (params.has_epilog) {
+    c_smem->setMemoryType(MemoryType::Shared);
+  }

   // Set parallelization:
   // TODO: this section goes to a separate matmul util,
@@ -534,37 +541,44 @@
     bcr->skewDoubleBuffer();
   }

+  auto output_buffer = params.has_epilog ? c_smem : c;
+
   scheduler_utils::BoundedDirectionalTransformPropagator::forward(
       cc,
       -1,
-      {c_smem},
+      {output_buffer},
       scheduler_utils::BoundedDirectionalTransformPropagator::Options()
           .propagateParallelType()
           .propagateToBoundary());

-  // Epilog schedule:
-  scheduler_utils::BoundedDirectionalTransformPropagator::forward(
-      c_smem,
-      3,
-      {c},
-      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
-          .propagateParallelType()
-          .propagateToBoundary());
-
-  c_smem->computeAt(c, 3);
-  c->reorder({{-1, -2}, {-2, -1}});
-  // 16 x 128, with half of the warps:
-
-  // Output vectorize by 4:
-  c->split(-2, 2);
-  c->split(-1, 4);
-
-  // [8, 2, 32, 4]
-  c->axis(-3)->parallelize(ParallelType::TIDy);
-  c->axis(-2)->parallelize(ParallelType::TIDx);
-  c->axis(-1)->parallelize(ParallelType::Vectorize);
-  c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
-  c_smem->doubleBuffer();
+  // Epilog schedule (To be built out):
+  if (params.has_epilog) {
+    scheduler_utils::BoundedDirectionalTransformPropagator::forward(
+        c_smem,
+        3,
+        {c},
+        scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+            .propagateParallelType()
+            .propagateToBoundary());
+
+    c_smem->computeAt(c, 3);
+    c->reorder({{-1, -2}, {-2, -1}});
+    // 16 x 128, with half of the warps:
+
+    // Output vectorize by 4:
+    c->split(-2, 2);
+    c->split(-1, 4);
+
+    // [8, 2, 32, 4]
+    c->axis(-3)->parallelize(ParallelType::TIDy);
+    c->axis(-2)->parallelize(ParallelType::TIDx);
+    c->axis(-1)->parallelize(ParallelType::Vectorize);
+    c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
+    c_smem->doubleBuffer();
+  } else {
+    // Always vector
+    c->axis(-1)->parallelize(ParallelType::Vectorize);
+  }

   if (params.index_lift_options.lift_gmem_read_address) {
     a->liftReadAddress();
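As a side note on the epilog branch above: the in-code comment puts the per-CTA output tile at 16 x 128, and the splits c->split(-2, 2) and c->split(-1, 4) produce the [8, 2, 32, 4] layout, i.e. 2 threads on TIDy, 32 on TIDx, a vector width of 4, and a serial factor of 8. A standalone arithmetic check of that layout (a sketch added for readability; the 16 x 128 tile size is taken from the comment, and the four-warp CTA size is an assumption):

// Standalone check of the epilog tiling arithmetic (not part of the commit).
#include <cassert>

int main() {
  const int tile_rows = 16, tile_cols = 128; // per-CTA output tile, per the "16 x 128" comment
  const int tidy = 2, tidx = 32, vec = 4;    // trailing dims of the [8, 2, 32, 4] layout
  const int serial = (tile_rows * tile_cols) / (tidy * tidx * vec);
  assert(serial == 8);        // matches the leading 8 in [8, 2, 32, 4]
  assert(tidy * tidx == 64);  // 64 threads = 2 warps, i.e. "half of the warps",
                              // assuming a 128-thread (4-warp) CTA
  return 0;
}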
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/matmul.h
@@ -51,6 +51,9 @@ class MatmulParam {

   //! Enables predicate peeling mainloop:
   bool peel_main_loop = true;
+
+  //! Enables an epilog schedule
+  bool has_epilog = false;
 };

 //! Prototype auto scheduling function.
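A caller-side sketch of the new option (hypothetical: the commit only shows the flag itself, so the helper name below is made up and the argument order of scheduleMatmul is an assumption, not taken from this diff). With the default left at false, the shared-memory epilog path is skipped unless a caller opts in:

// Hypothetical helper (not from the commit): flips the new option on an
// already-configured MatmulParam before running the scheduler.
#include <torch/csrc/jit/codegen/cuda/scheduler/matmul.h>

using namespace torch::jit::fuser::cuda;

void scheduleMatmulWithEpilog(
    TensorView* c,
    TensorView* a,
    TensorView* b,
    MatmulParam& params) {
  params.has_epilog = true; // new flag from matmul.h above; default stays false
  scheduleMatmul(c, a, b, params); // assumed argument order: output, operands, params
}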
