fix failed test cases
liqiangxl committed Jun 23, 2023
1 parent 729fc07 commit 30e0791
Showing 6 changed files with 173 additions and 44 deletions.
127 changes: 100 additions & 27 deletions csrc/scheduler/matmul.cpp
@@ -93,8 +93,10 @@ void swizzleSharedMemory(
check_concrete_static_dim(shared_mem_tv->axis(-1 - shift));

// Extract the constant sizes of the swizzled tile
const int64_t tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt();
const int64_t tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt();
const int64_t tile_size_x =
shared_mem_tv->axis(-2 - shift)->extent()->evaluateInt();
const int64_t tile_size_y =
shared_mem_tv->axis(-1 - shift)->extent()->evaluateInt();

if (isTuring(params.mma_macro) || isAmpere(params.mma_macro)) {
// Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e.
@@ -215,7 +217,6 @@ void swizzleSharedMemory(
// assert(row_stride >= 0);
// assert(num_megabanks >= 0);
int64_t row_stride_znz = row_stride % num_megabanks;

/* Consider the following function in Z/nZ:
* f(i; init) = init + i * stride
* where init is the initial position of the pointer in the clock when we
@@ -366,7 +367,18 @@ void swizzleSharedMemory(
* 7| | |
* +----------+----------+
*/

int64_t swizzle_period = n_rows / repeated_pattern_size;
// tile_size_y will be split by n_cols and then by swizzle_period.
// Avoid over-splitting; if it happens, bank conflicts won't be fully removed.
if (tile_size_y < n_cols * swizzle_period) {
swizzle_period = tile_size_y / n_cols;
repeated_pattern_size = n_rows / swizzle_period;
}
// e.g. tile_size_y = 96 in FusionAmpereMatmulTileCheck4warp_CUDA
while (tile_size_y / n_cols % swizzle_period) {
swizzle_period /= 2;
repeated_pattern_size = n_rows / swizzle_period;
}
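// A worked example with hypothetical values (only tile_size_y = 96 is
// taken from the test named above): n_rows = 8, repeated_pattern_size = 1,
// and n_cols = 8 give swizzle_period = 8. Since 96 >= 8 * 8 the first
// adjustment does not fire, but 96 / 8 % 8 == 4, so the loop halves
// swizzle_period to 4 (repeated_pattern_size becomes 8 / 4 = 2); then
// 96 / 8 % 4 == 0 and the loop exits.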
// -2 -1
// [row, col]
TORCH_INTERNAL_ASSERT(
@@ -385,7 +397,6 @@
}
// -5 -4 -3 -2 -1
// [matrix id, repeat, pattern, matrix id, matrix]
int64_t swizzle_period = n_rows / repeated_pattern_size;
if (!shift) {
TORCH_INTERNAL_ASSERT(
tile_size_y % (swizzle_period * n_cols) == 0,
@@ -443,7 +454,6 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) {

// Swizzle the shared memory data layout
swizzleSharedMemory(shared_mem_tv, params, 0);

// Assuming we are always vectorizing smem write by 128b at the moment:
// TODO: would need a data-type and alignment dependent interface
// to support non-vectorizable shapes.
@@ -465,6 +475,7 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) {

void schedule_output_tensor(
TensorView* c,
MatmulParams::TileRasterizationOrder cta_order,
int warp_tile_m,
int instruction_tile_m) {
// [a,b,128,128]
@@ -483,8 +494,19 @@
c->reorder({{2, 3}, {3, 2}});
//[a,b,wm/im, 128/wm, im/2, 2, 128/4, 4]
int axis = 0;
c->axis(axis++)->parallelize(ParallelType::BIDx);
c->axis(axis++)->parallelize(ParallelType::BIDy);
switch (cta_order) {
case MatmulParams::TileRasterizationOrder::RowMajor:
c->axis(axis++)->parallelize(ParallelType::BIDx);
c->axis(axis++)->parallelize(ParallelType::BIDy);
break;
case MatmulParams::TileRasterizationOrder::ColumnMajor:
c->axis(axis++)->parallelize(ParallelType::BIDy);
c->axis(axis++)->parallelize(ParallelType::BIDx);
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Invalid TileRasterizationOrder passed to Matmul scheduler");
}
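// Note: the two cases map the CTA-tile grid axes to BIDx/BIDy in swapped
// order, so the launch walks output tiles row-first or column-first
// (a descriptive note; this ordering is typically chosen to improve L2
// reuse of the operand tiles).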
c->axis(axis++)->parallelize(ParallelType::Serial);
c->axis(axis++)->parallelize(ParallelType::TIDz);
c->axis(axis++)->parallelize(ParallelType::Serial);
@@ -495,7 +517,9 @@

void schedule_epilogue_tensor(
TensorView* c_smem,
const MatMulTileOptions& gemm_tile) {
MatmulParams::TileRasterizationOrder cta_order,
const MatMulTileOptions& gemm_tile,
const MmaOptions& mma_opts) {
auto warp_tile = gemm_tile.warp_tile;
auto instruction_tile = gemm_tile.instruction_tile;
// transform to its producer, mma results
@@ -513,23 +537,64 @@
c_smem->reorder({{2, 3}, {3, 2}, {4, 6}, {5, 4}, {6, 5}, {7, 7}});
//[a,b,64/16,128/64,128/64,64/8, 16, 8]

// MMA
c_smem->split(-2, 8);
c_smem->split(-1, 2);
//[a,b,64/16,128/64,128/64,64/8, 16/8, 8, 8/2, 2]
c_smem->merge(-3, -2);
auto macro = mma_opts.macro;
int m_pos = -2;
switch (macro) {
case MmaOptions::MacroType::Volta_16_16_4:
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
// [16, 8]
c_smem->split(-2, 8);
c_smem->split(-1, 2);
c_smem->merge(-3, -2);
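// Layout note (assumes the standard mma.m16n8 f32 accumulator fragment):
// [16, 8] splits into [2, 8, 4, 2]; merging the 8 row groups with the
// 4 column pairs yields the 32 lanes of a warp (bound to TIDx below),
// leaving a 2x2 fragment per lane.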
c_smem->axis(m_pos)->parallelize(ParallelType::TIDx);
break;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
// m
// [16, 16 (,R)]
c_smem->split(m_pos + 1, 8);
// m
// [16, n2, 8 (,R)]
c_smem->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}});

// m
// [n2, 16, 8 (,R)]
c_smem->split(m_pos, 8);
c_smem->split(m_pos + 1, 2);

// m
// [2o, 8o, 4i, 2i (,R)]
c_smem->merge(m_pos - 1);
c_smem->axis(m_pos)->parallelize(ParallelType::TIDx);

break;
default:
TORCH_CHECK(
false, "scheduleMmaWarp: unsupported mma option ", toString(macro));
break;
}

// parallel
int axis = 0;
c_smem->axis(axis++)->parallelize(ParallelType::BIDx);
c_smem->axis(axis++)->parallelize(ParallelType::BIDy);
switch (cta_order) {
case MatmulParams::TileRasterizationOrder::RowMajor:
c_smem->axis(axis++)->parallelize(ParallelType::BIDx);
c_smem->axis(axis++)->parallelize(ParallelType::BIDy);
break;
case MatmulParams::TileRasterizationOrder::ColumnMajor:
c_smem->axis(axis++)->parallelize(ParallelType::BIDy);
c_smem->axis(axis++)->parallelize(ParallelType::BIDx);
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Invalid TileRasterizationOrder passed to Matmul scheduler");
}
c_smem->axis(axis++)->parallelize(ParallelType::Serial);
c_smem->axis(axis++)->parallelize(ParallelType::TIDz);
c_smem->axis(axis++)->parallelize(ParallelType::TIDy);
c_smem->axis(axis++)->parallelize(ParallelType::Serial);
c_smem->axis(axis++)->parallelize(ParallelType::Serial);
c_smem->axis(axis++)->parallelize(ParallelType::TIDx);
c_smem->axis(axis++)->parallelize(ParallelType::Vectorize);
c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
}

void mergeBackAfterSwizzleTransform(
@@ -555,9 +620,6 @@ void scheduleEpilog(

// Swizzle the shared memory data layout
swizzleSharedMemory(c_smem, params, 0);

// Actual schedule
schedule_epilogue_tensor(c_smem, gemm_tile);
}
} // namespace

@@ -869,8 +931,19 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {

if (params.has_smem_epilogue) {
scheduleEpilog(c_smem, mma_result, params, gemm_tile);
schedule_output_tensor(
c, gemm_tile.warp_tile.m, gemm_tile.instruction_tile.m);
const auto& mma_options =
mma_builder.operand(MmaOptions::Operand::Accumulator).build();
schedule_epilogue_tensor(c_smem, params.cta_order, gemm_tile, mma_options);
// schedule_output_tensor(
// c, params.cta_order, gemm_tile.warp_tile.m,
// gemm_tile.instruction_tile.m);
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
c_smem,
-1,
{c},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
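// Note: propagating forward from c_smem mirrors its transforms and
// parallel types onto the global output c up to the boundary, replacing
// the dedicated schedule_output_tensor pass commented out above.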
} else {
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
mma_result,
@@ -879,9 +952,9 @@
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
// Always vectorize
c->axis(-1)->parallelize(ParallelType::Vectorize);
}
// Always vectorize
c->axis(-1)->parallelize(ParallelType::Vectorize);
// auto inline for all tensors except register tensors and output tensor
inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb, c_smem, c}));

2 changes: 1 addition & 1 deletion csrc/scheduler/matmul_heuristic.h
@@ -92,7 +92,7 @@ class MatmulParams : public HeuristicParams {

//! Swizzle MMA results in shared memory to enable
//! coalesced writes to global memory
bool has_smem_epilogue = false;
bool has_smem_epilogue = true;

std::string toString() const override {
std::stringstream ss;
5 changes: 5 additions & 0 deletions csrc/scheduler/matmul_utils.cpp
@@ -377,6 +377,11 @@ std::shared_ptr<MatmulParams> getMatmulHeuristics(
// Disable magic zero for matmul kernels
params->cparams.enable_magic_zero = false;

// Check if we have enough shared memory for the epilogue
params->has_smem_epilogue = mma_utils::hasEnoughSharedMemoryForEpilogue(
params->tile_sizes,
params->double_buffer_options.smem_double_buffer_stage);

if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) {
printMsg(params->toString());
}
31 changes: 29 additions & 2 deletions csrc/scheduler/mma_utils.cpp
@@ -6,6 +6,7 @@
*/
// clang-format on

#include <ATen/cuda/CUDAContext.h>
#include <device_lower/utils.h>
#include <expr_evaluator.h>
#include <ir/printer.h>
@@ -14,11 +15,35 @@
#include <scheduler/utils.h>
#include <variant>
#include "mma_type.h"

namespace nvfuser {

namespace mma_utils {

bool hasEnoughSharedMemoryForEpilogue(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage) {
auto properties = at::cuda::getDeviceProperties(
c10::Device(c10::DeviceType::CUDA, 0).index());
const int64_t device_smem_limit = (int64_t)properties->sharedMemPerBlockOptin;

// see scheduleContiguousVectorLoad
const int64_t vector_word = 8;
auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile;
const int64_t round_to_factor =
warp_dims.m * warp_dims.n * warp_dims.k * 32 * vector_word;
const int64_t mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k;
const int64_t nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k;
const int64_t smem_a = ceilDiv(mk, round_to_factor) * round_to_factor *
dataTypeSize(DataType::Half) * smem_double_buffer_stage;
const int64_t smem_b = ceilDiv(nk, round_to_factor) * round_to_factor *
dataTypeSize(DataType::Half) * smem_double_buffer_stage;
const int64_t smem_c = gemm_tile.cta_tile.m * gemm_tile.cta_tile.n *
dataTypeSize(DataType::Float);
int64_t smem_size = smem_a + smem_b + smem_c;

return smem_size <= device_smem_limit;
}
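// A worked example with hypothetical tile sizes: a 128x128x32 CTA tile,
// a 64x64x32 warp tile, and 3 double-buffer stages give warp_dims = 2x2x1
// and round_to_factor = 2 * 2 * 1 * 32 * 8 = 1024. Then
// smem_a = smem_b = ceilDiv(128 * 32, 1024) * 1024 * 2 * 3 = 24576 bytes,
// smem_c = 128 * 128 * 4 = 65536 bytes, for a total of 114688 bytes
// (112 KiB): within the ~164 KB opt-in limit of an A100, but over the
// ~99 KB limit of sm_86 devices.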

void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
// Assumes
// [M, N, K]
@@ -430,7 +455,9 @@ void checkDimSize(
":",
id->extent()->evaluateInt(),
"vs",
expect[axis_index]);
expect[axis_index],
"\n for tv: ",
tv->toString());
}
}

5 changes: 5 additions & 0 deletions csrc/scheduler/mma_utils.h
@@ -17,6 +17,11 @@ namespace nvfuser {

namespace mma_utils {

//! Check if there is enough shared memory for the given tile options
TORCH_CUDA_CU_API bool hasEnoughSharedMemoryForEpilogue(
const MatMulTileOptions& gemm_tile,
const int smem_double_buffer_stage);

//! Utilities in this namespace facilitates scheduling matmul kernels with
//! hierarchichal tiling specified in MatMulTileOptions.
