diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index fed634fc487..7e93d599eac 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -93,8 +93,10 @@ void swizzleSharedMemory( check_concrete_static_dim(shared_mem_tv->axis(-1 - shift)); // Extract the constant sizes of the swizzled tile - const int64_t tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt(); - const int64_t tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt(); + const int64_t tile_size_x = + shared_mem_tv->axis(-2 - shift)->extent()->evaluateInt(); + const int64_t tile_size_y = + shared_mem_tv->axis(-1 - shift)->extent()->evaluateInt(); if (isTuring(params.mma_macro) || isAmpere(params.mma_macro)) { // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. @@ -215,7 +217,6 @@ void swizzleSharedMemory( // assert(row_stride >= 0); // assert(num_megabanks >= 0); int64_t row_stride_znz = row_stride % num_megabanks; - /* Consider the following function in Z/nZ: * f(i; init) = init + i * stride * where init is the initial position of the pointer in the clock when we @@ -366,7 +367,18 @@ void swizzleSharedMemory( * 7| | | * +----------+----------+ */ - + int64_t swizzle_period = n_rows / repeated_pattern_size; + // tile_size_y will be splitted by n_cols and then by swizzle_period + // avoid over split, won't fully remove bank conflict if happens. + if (tile_size_y < n_cols * swizzle_period) { + swizzle_period = tile_size_y / n_cols; + repeated_pattern_size = n_rows / swizzle_period; + } + // e.g. tile_size_y = 96 in FusionAmpereMatmulTileCheck4warp_CUDA + while (tile_size_y / n_cols % swizzle_period) { + swizzle_period /= 2; + repeated_pattern_size = n_rows / swizzle_period; + } // -2 -1 // [row, col] TORCH_INTERNAL_ASSERT( @@ -385,7 +397,6 @@ void swizzleSharedMemory( } // -5 -4 -3 -2 -1 // [matrix id, repeat, pattern, matrix id, matrix] - int64_t swizzle_period = n_rows / repeated_pattern_size; if (!shift) { TORCH_INTERNAL_ASSERT( tile_size_y % (swizzle_period * n_cols) == 0, @@ -443,7 +454,6 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { // Swizzle the shared memory data layout swizzleSharedMemory(shared_mem_tv, params, 0); - // Assuming we are always vectorizing smem write by 128b at the moment: // TODO: would need a data-type and alignment dependent interface // to support non-vectorizable shapes. @@ -465,6 +475,7 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) { void schedule_output_tensor( TensorView* c, + MatmulParams::TileRasterizationOrder cta_order, int warp_tile_m, int instruction_tile_m) { // [a,b,128,128] @@ -483,8 +494,19 @@ void schedule_output_tensor( c->reorder({{2, 3}, {3, 2}}); //[a,b,wm/im, 128/wm, im/2, 2, 128/4, 4] int axis = 0; - c->axis(axis++)->parallelize(ParallelType::BIDx); - c->axis(axis++)->parallelize(ParallelType::BIDy); + switch (cta_order) { + case MatmulParams::TileRasterizationOrder::RowMajor: + c->axis(axis++)->parallelize(ParallelType::BIDx); + c->axis(axis++)->parallelize(ParallelType::BIDy); + break; + case MatmulParams::TileRasterizationOrder::ColumnMajor: + c->axis(axis++)->parallelize(ParallelType::BIDy); + c->axis(axis++)->parallelize(ParallelType::BIDx); + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Invalid TileRasterizationOrder passed to Matmul scheduler"); + } c->axis(axis++)->parallelize(ParallelType::Serial); c->axis(axis++)->parallelize(ParallelType::TIDz); c->axis(axis++)->parallelize(ParallelType::Serial); @@ -495,7 +517,9 @@ void schedule_output_tensor( void schedule_epilogue_tensor( TensorView* c_smem, - const MatMulTileOptions& gemm_tile) { + MatmulParams::TileRasterizationOrder cta_order, + const MatMulTileOptions& gemm_tile, + const MmaOptions& mma_opts) { auto warp_tile = gemm_tile.warp_tile; auto instruction_tile = gemm_tile.instruction_tile; // transform to its producer, mma results @@ -513,23 +537,64 @@ void schedule_epilogue_tensor( c_smem->reorder({{2, 3}, {3, 2}, {4, 6}, {5, 4}, {6, 5}, {7, 7}}); //[a,b,64/16,128/64,128/64,64/8, 16, 8] - // MMA - c_smem->split(-2, 8); - c_smem->split(-1, 2); - //[a,b,64/16,128/64,128/64,64/8, 16/8, 8, 8/2, 2] - c_smem->merge(-3, -2); + auto macro = mma_opts.macro; + int m_pos = -2; + switch (macro) { + case MmaOptions::MacroType::Volta_16_16_4: + break; + case MmaOptions::MacroType::Turing_16_8_16: + case MmaOptions::MacroType::Ampere_16_8_16: + // [16, 8] + c_smem->split(-2, 8); + c_smem->split(-1, 2); + c_smem->merge(-3, -2); + c_smem->axis(m_pos)->parallelize(ParallelType::TIDx); + break; + case MmaOptions::MacroType::Turing_16_16_16: + case MmaOptions::MacroType::Ampere_16_16_16: + // m + // [16, 16 (,R)] + c_smem->split(m_pos + 1, 8); + // m + // [16, n2, 8 (,R)] + c_smem->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}}); + + // m + // [n2, 16, 8 (,R)] + c_smem->split(m_pos, 8); + c_smem->split(m_pos + 1, 2); + + // m + // [2o, 8o, 4i, 2i (,R)] + c_smem->merge(m_pos - 1); + c_smem->axis(m_pos)->parallelize(ParallelType::TIDx); + + break; + default: + TORCH_CHECK( + false, "scheduleMmaWarp: unsupported mma option ", toString(macro)); + break; + } // parallel int axis = 0; - c_smem->axis(axis++)->parallelize(ParallelType::BIDx); - c_smem->axis(axis++)->parallelize(ParallelType::BIDy); + switch (cta_order) { + case MatmulParams::TileRasterizationOrder::RowMajor: + c_smem->axis(axis++)->parallelize(ParallelType::BIDx); + c_smem->axis(axis++)->parallelize(ParallelType::BIDy); + break; + case MatmulParams::TileRasterizationOrder::ColumnMajor: + c_smem->axis(axis++)->parallelize(ParallelType::BIDy); + c_smem->axis(axis++)->parallelize(ParallelType::BIDx); + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Invalid TileRasterizationOrder passed to Matmul scheduler"); + } c_smem->axis(axis++)->parallelize(ParallelType::Serial); c_smem->axis(axis++)->parallelize(ParallelType::TIDz); c_smem->axis(axis++)->parallelize(ParallelType::TIDy); - c_smem->axis(axis++)->parallelize(ParallelType::Serial); - c_smem->axis(axis++)->parallelize(ParallelType::Serial); - c_smem->axis(axis++)->parallelize(ParallelType::TIDx); - c_smem->axis(axis++)->parallelize(ParallelType::Vectorize); + c_smem->axis(-1)->parallelize(ParallelType::Vectorize); } void mergeBackAfterSwizzleTransform( @@ -555,9 +620,6 @@ void scheduleEpilog( // Swizzle the shared memory data layout swizzleSharedMemory(c_smem, params, 0); - - // Actual schedule - schedule_epilogue_tensor(c_smem, gemm_tile); } } // namespace @@ -869,8 +931,19 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { if (params.has_smem_epilogue) { scheduleEpilog(c_smem, mma_result, params, gemm_tile); - schedule_output_tensor( - c, gemm_tile.warp_tile.m, gemm_tile.instruction_tile.m); + const auto& mma_options = + mma_builder.operand(MmaOptions::Operand::Accumulator).build(); + schedule_epilogue_tensor(c_smem, params.cta_order, gemm_tile, mma_options); + // schedule_output_tensor( + // c, params.cta_order, gemm_tile.warp_tile.m, + // gemm_tile.instruction_tile.m); + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + c_smem, + -1, + {c}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); } else { scheduler_utils::BoundedDirectionalTransformPropagator::forward( mma_result, @@ -879,9 +952,9 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { scheduler_utils::BoundedDirectionalTransformPropagator::Options() .propagateParallelType() .propagateToBoundary()); - // Always vector - c->axis(-1)->parallelize(ParallelType::Vectorize); } + // Always vector + c->axis(-1)->parallelize(ParallelType::Vectorize); // auto inline for all tensors except register tensors and output tensor inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb, c_smem, c})); diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 95aec144f1f..9ad94d236ea 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -92,7 +92,7 @@ class MatmulParams : public HeuristicParams { //! Swizzle MMA results in shared memory //! coalesced write to global memory - bool has_smem_epilogue = false; + bool has_smem_epilogue = true; std::string toString() const override { std::stringstream ss; diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index b4ea0f04d98..e28f1b00fd4 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -377,6 +377,11 @@ std::shared_ptr getMatmulHeuristics( // Disable magic zero for matmul kernels params->cparams.enable_magic_zero = false; + // Check if we have enough shared memory for epilogue + params->has_smem_epilogue = mma_utils::hasEnoughSharedMemoryForEpilogue( + params->tile_sizes, + params->double_buffer_options.smem_double_buffer_stage); + if (isDebugDumpEnabled(DebugDumpOption::MatmulChecks)) { printMsg(params->toString()); } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 2aa8a70a5e8..1718b50ba79 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -6,6 +6,7 @@ */ // clang-format on +#include #include #include #include @@ -14,11 +15,35 @@ #include #include #include "mma_type.h" - namespace nvfuser { namespace mma_utils { +bool hasEnoughSharedMemoryForEpilogue( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage) { + auto properties = at::cuda::getDeviceProperties( + c10::Device(c10::DeviceType::CUDA, 0).index()); + const int64_t device_smem_limit = (int64_t)properties->sharedMemPerBlockOptin; + + // see scheduleContiguousVectorLoad + const int64_t vector_word = 8; + auto warp_dims = gemm_tile.cta_tile / gemm_tile.warp_tile; + const int64_t round_to_factor = + warp_dims.m * warp_dims.n * warp_dims.k * 32 * vector_word; + const int64_t mk = gemm_tile.cta_tile.m * gemm_tile.cta_tile.k; + const int64_t nk = gemm_tile.cta_tile.n * gemm_tile.cta_tile.k; + const int64_t smem_a = ceilDiv(mk, round_to_factor) * round_to_factor * + dataTypeSize(DataType::Half) * smem_double_buffer_stage; + const int64_t smem_b = ceilDiv(nk, round_to_factor) * round_to_factor * + dataTypeSize(DataType::Half) * smem_double_buffer_stage; + const int64_t smem_c = gemm_tile.cta_tile.m * gemm_tile.cta_tile.n * + dataTypeSize(DataType::Float); + int64_t smem_size = smem_a + smem_b + smem_c; + + return smem_size <= device_smem_limit; +} + void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { // Assumes // [M, N, K] @@ -430,7 +455,9 @@ void checkDimSize( ":", id->extent()->evaluateInt(), "vs", - expect[axis_index]); + expect[axis_index], + "\n for tv: ", + tv->toString()); } } diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 8a6dec5df96..cdf9f163707 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -17,6 +17,11 @@ namespace nvfuser { namespace mma_utils { +//! Check if there is enough shared memory for the given tile options +TORCH_CUDA_CU_API bool hasEnoughSharedMemoryForEpilogue( + const MatMulTileOptions& gemm_tile, + const int smem_double_buffer_stage); + //! Utilities in this namespace facilitates scheduling matmul kernels with //! hierarchichal tiling specified in MatMulTileOptions. diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 164af10beb8..b1d7ae92758 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -56,6 +56,8 @@ namespace nvfuser { using namespace at::indexing; +namespace MatMulUtils {} + // MMA unit test for a single instruction tile. VoltaTT TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { Fusion fusion; @@ -3153,7 +3155,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { gemm_tile.cta_tile = GemmTile(128, 128, 64); gemm_tile.warp_tile = GemmTile(64, 64, 64); gemm_tile.instruction_tile = GemmTile(16, 16, 16); - MatmulParams params; params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16; params.tile_sizes = gemm_tile; @@ -3262,6 +3263,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; + params.has_smem_epilogue = mma_utils::hasEnoughSharedMemoryForEpilogue( + gemm_tile, params.double_buffer_options.smem_double_buffer_stage); scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); @@ -3275,7 +3278,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { {inputs.first, inputs.second}, LaunchParams(), matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + // if cta_tile.n can't be fully divided by 64 (n_cols * + // swizzle_period), then we can't fully remove the bank conflict, see + // swizzleSharedMemory + const bool expected_bank_conflict = + params.has_smem_epilogue && gemm_tile.cta_tile.n % 64 > 0; + if (!expected_bank_conflict) { + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + } auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -3289,6 +3299,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { k_size); } } + break; } } @@ -3325,6 +3336,10 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; + params.has_smem_epilogue = + mma_utils::hasEnoughSharedMemoryForEpilogue( + gemm_tile, + params.double_buffer_options.smem_double_buffer_stage); scheduleMatmul(&fusion, params); @@ -3339,7 +3354,14 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck8warp_CUDA) { {inputs.first, inputs.second}, LaunchParams(), matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + // if cta_tile.n can't be fully divided by 64 (n_cols * + // swizzle_period), then we can't fully remove the bank conflict, see + // swizzleSharedMemory + const bool expected_bank_conflict = + params.has_smem_epilogue && gemm_tile.cta_tile.n % 64 > 0; + if (!expected_bank_conflict) { + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + } auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), @@ -3383,7 +3405,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { params.double_buffer_options.double_buffer_smem_write = true; params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = 2; - + params.has_smem_epilogue = mma_utils::hasEnoughSharedMemoryForEpilogue( + gemm_tile, params.double_buffer_options.smem_double_buffer_stage); scheduleMatmul(&fusion, params); auto inputs = matmulAtInput(M, N, K, layout); @@ -3397,7 +3420,11 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck6warp_CUDA) { {inputs.first, inputs.second}, LaunchParams(), matmul_cparams)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + const bool expected_bank_conflict = + params.has_smem_epilogue && gemm_tile.cta_tile.n % 64 > 0; + if (!expected_bank_conflict) { + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + } auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); @@ -4094,15 +4121,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulEpilogue_CUDA) { inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); // check bank conflicts - const auto& bank_conflict = getBankConflictInfo(fe.kernel()); - if (!bank_conflict.empty()) { - for (auto it = bank_conflict.begin(); it != bank_conflict.end(); it++) { - std::cout << "Bank conflict expression: " << it->first->toString() - << "read conflict= " << it->second.first - << ", write conflict= " << it->second.second << std::endl; - } - ASSERT_TRUE(bank_conflict.empty()); - } + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); break; }