From 7525cc0bac80a4d21a31d2d77bfd3c7f7580d079 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Date: Fri, 22 Mar 2024 08:01:01 -0400
Subject: [PATCH] Unswizzle before grid reduction in split-K (#1534)

Serial grid reductions are used in split-K matmuls as of #1510. This means
we load and store elements in the reduction tensor according to the indexing
of the work buffer. This is unlike ordinary grid reductions that use
`gridReduce`, which reduces individual elements using a scheme that ensures
coalescing by indexing into the work buffer based on `threadIdx` and
`blockIdx`. Because of this lack of coalescing, the split-K accesses are
currently inefficient.

We already ensure coalesced output stores in matmuls when possible by using
smem for the epilogue (#387): a shared memory buffer is used to communicate
elements between threads so that the resulting tensor has a proper global
access pattern when it is written out to global memory as a tile of the
output. Before this PR, if we used split-K with `use_smem_epilogue = true`,
the store to global memory was coalesced, but the accesses during the
split-K reduction were not. This PR modifies scheduling so that, in those
cases, the smem epilogue tensor is placed before the split-K sum, meaning
unswizzling happens before the reduction completes. As a result, the
reduction accesses are coalesced.

This is a generated kernel from
`NVFuserTest.FusionAmpereMatmulSplitKBias_CUDA`:

```c++
  // ... (main loop) ...
      #pragma unroll
      for(nvfuser_index_t i59 = 0; i59 < 4LL; ++i59) {
        nvfuser_index_t i104;
        i104 = 8LL * i59;
        nvfuser_index_t i105;
        i105 = 32LL * i59;
        #pragma unroll
        for(nvfuser_index_t i61 = 0; i61 < 8LL; ++i61) {
          nvfuser_index_t i106;
          i106 = 4LL * i61;
          asm(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            :"=f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[0]),
             "=f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[1]),
             "=f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[2]),
             "=f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[3])
            :"r"((*reinterpret_cast*>(&T4[i104]))[0]),
             "r"((*reinterpret_cast*>(&T4[i104]))[1]),
             "r"((*reinterpret_cast*>(&T4[i104]))[2]),
             "r"((*reinterpret_cast*>(&T4[i104]))[3]),
             "r"((*reinterpret_cast*>(&T5[i106]))[0]),
             "r"((*reinterpret_cast*>(&T5[i106]))[1]),
             "f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[0]),
             "f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[1]),
             "f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[2]),
             "f"((*reinterpret_cast*>(&T16[(i105 + i106)]))[3])
          );
        }
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    __syncthreads();
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i107 = 0; i107 < 4LL; ++i107) {
    nvfuser_index_t i108;
    i108 = 32LL * i107;
    nvfuser_index_t i109;
    i109 = i38 + (2048LL * i107);
    #pragma unroll
    for(nvfuser_index_t i110 = 0; i110 < 8LL; ++i110) {
      nvfuser_index_t i111;
      i111 = i108 + (4LL * i110);
      nvfuser_index_t i112;
      i112 = i11 + i110;
      nvfuser_index_t i113;
      i113 = (i109 + (32LL * (i112 / 4LL))) + (8LL * (i39 ^ (i112 % 4LL)));
      #pragma unroll
      for(nvfuser_index_t i114 = 0; i114 < 2LL; ++i114) {
        loadGeneric(
          &T17[(i113 + (1024LL * i114))],
          &T16[(i111 + (2LL * i114))]);
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Allocate global tensor T19
  grid_sync::blockSerializeWait(&T19[index_utils::maskedOffset(blockIdx, gridDim)]);
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i115 = 0; i115 < 32LL; ++i115) {
    nvfuser_index_t i116;
    i116 = i115 + nvfuser_zero;
    nvfuser_index_t i117;
    i117 = i44 + (i45 * i116);
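    // [Editorial annotation, not part of the generated kernel] What follows
    // is the new split-K reduction path. T17 is the shared-memory epilogue
    // buffer that already holds this CTA's unswizzled partial tile, T20 is
    // the global work buffer for the serial grid reduction, and T19 holds
    // the semaphores used by blockSerializeWait/blockSerializeRelease so
    // that the CTAs covering the same output tile take turns accumulating
    // into T20. Because the tile was unswizzled into T17 first, the reads
    // of T17 and the global accesses to T20 below are coalesced.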
    nvfuser_index_t i118;
    i118 = i47 + (4LL * i115);
    bool b119;
    b119 = i55 < (-(4LL * i116));
    bool b120;
    b120 = b54 && b119;
    Array T6;
    T6.set(float(0.000000000e+00f));
    // Allocate global tensor T20
    reduction::serialReductionStep(
      &T6[0LL],
      &T17[(i42 + (512LL * i115))],
      0.000000000e+00f,
      &T20[i117],
      [](float &a, float b) { a = a + b; },
      index_utils::maskedOffset(blockIdx, gridDim) == 0,
      index_utils::maskedOffset(blockIdx, gridDim) == index_utils::maskedSize(gridDim) - 1,
      b120,
      b120);
    Array T10;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4LL; ++i121) {
      __half T18[1LL];
      T18[0LL] = 0LL;
      if (b119) {
        T18[0LL] = T2[(i118 + ((i48 + (i121 + nvfuser_zero)) / 128LL))];
      }
      __half T7[1LL];
      T7[0LL] = T18[0LL];
      float T8[1LL];
      T8[0LL] = __half2float(T7[0LL]);
      T10[i121] = T6[i121] + T8[0LL];
    }
    if ((b56 && b119)) {
      loadLocalToGlobal(
        &T9[i117],
        &T10[0LL]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease(&T19[index_utils::maskedOffset(blockIdx, gridDim)]);
}
```

Note that the `i135` loop will be smaller once we have #1528, at which point
it would more clearly show the reduction followed by the loop for the
predicated bias epilogue.

(This diff is best viewed with whitespace changes hidden, since many of the
changes are indentation only.)
---
 benchmarks/cpp/matmul.cpp         |  50 +++-
 csrc/device_lower/validation.cpp  |  14 +-
 csrc/scheduler/matmul.cpp         | 193 ++++++++++-----
 tests/cpp/test_gpu_tensorcore.cpp | 385 +++++++++++++++++-------------
 4 files changed, 402 insertions(+), 240 deletions(-)

diff --git a/benchmarks/cpp/matmul.cpp b/benchmarks/cpp/matmul.cpp
index bcfdc132bf3..c5068455bca 100644
--- a/benchmarks/cpp/matmul.cpp
+++ b/benchmarks/cpp/matmul.cpp
@@ -16,6 +16,7 @@ #include #include #include +#include #include #include
@@ -145,6 +146,9 @@ static void SingleMatmulBase( int64_t n = benchmark_state.range(1); int64_t k = benchmark_state.range(2); + // inputs + at::manual_seed(0); + // Tensor inputs auto inputs = matmulAtInput2D(m, n, k, layout); auto expected_output = atMatmul(
@@ -159,9 +163,6 @@ static void SingleMatmulBase( preseg_passes::OptimizationPass::runPass(fusion); - // inputs - at::manual_seed(0); - KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder( {inputs.first, inputs.second});
@@ -250,6 +251,14 @@ MatmulParams getMatmulParams( params.double_buffer_options.double_buffer_smem_read = true; params.double_buffer_options.smem_double_buffer_stage = stage_number; params.splitk_factor = splitk_factor; + std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) = + mma_utils::generateSharedMemoryEpilogueHeuristics( + gemm_tile, + stage_number, + {DataType::Half, DataType::Half, DataType::Float}, + /*smem_a_reuse_guaranteed=*/true, + /*smem_b_reuse_guaranteed=*/true, + /*ignore_occupancy_drop=*/true); return params; }
@@ -362,7 +371,8 @@ static void NvFuserScheduler_Matmul( benchmark::State& benchmark_state, MmaLayout layout, int splitk_factor = 1, - bool partitionedk = false) { + bool partitionedk = false, + bool use_smem_epilogue = false) { int num_warps = benchmark_state.range(3); int number_of_stage = benchmark_state.range(4);
@@ -376,6 +386,15 @@ static void NvFuserScheduler_Matmul( auto params = getMatmulParams( cta_tile, number_of_stage, layout, partitionedk ?
1 : splitk_factor); + if (use_smem_epilogue) { + if (!params.use_smem_epilogue) { + benchmark_state.SkipWithError( + "Insufficient shared mem for smem epilogue"); + } + } else { + params.use_smem_epilogue = false; + params.promote_prologue_smem_reuse = false; + } NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( 8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state); @@ -630,13 +649,23 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) { // Use this for manual splitk. static void MatmulShapeWarpStageSpecificSplitK( benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "K", "warps", "stages", "splitk_factor"}); + b->ArgNames( + {"M", "N", "K", "warps", "stages", "splitk_factor", "smem_epilogue"}); for (long int num_warps : NumWarps) { for (long int num_stages : NumStages) { for (auto [m, n, k] : std::vector>(SplitKSpecificShapes)) { for (auto splitk_factor : {2, 3, 4, 5, 6}) { - b->Args({m, n, k, num_warps, num_stages, splitk_factor}); + for (bool use_smem_epilogue : {false, true}) { + b->Args( + {m, + n, + k, + num_warps, + num_stages, + splitk_factor, + use_smem_epilogue}); + } } } } @@ -727,8 +756,13 @@ static void NvFuserScheduler_Matmul_Manual( benchmark::State& benchmark_state, MmaLayout layout) { int splitk_factor = benchmark_state.range(5); + bool use_smem_epilogue = benchmark_state.range(6); NvFuserScheduler_Matmul( - benchmark_state, layout, splitk_factor, /*partitionedk=*/false); + benchmark_state, + layout, + splitk_factor, + /*partitionedk=*/false, + use_smem_epilogue); } #define SpecificSplitKBenchmark(layout) \ @@ -742,9 +776,7 @@ static void NvFuserScheduler_Matmul_Manual( ForAllLayouts(EagerModeBenchmark); ForAllLayouts(NvfuserMatmulBenchmark); -ForAllLayouts(AutoSplitKBenchmark); ForAllLayouts(SpecificSplitKBenchmark); -ForAllLayouts(AutoPartitionedKBenchmark); // Note: SplitK Reduction benchmarks are parametrized only by M, N. The splitk // factor is deduced automatically from N diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index de055538a3e..2a9be0f9619 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -388,6 +388,16 @@ class VectorizeValidator : public OptInDispatch { if (r_id->isReduction() || r_id->isBroadcast()) { continue; } + if ((tv->getMemoryType() == MemoryType::Shared || + tv->getMemoryType() == MemoryType::Local) && + r_id->isBlockDim()) { + // Inner-most parallelized dimensions don't count in allocation of + // shared and local tensors. + continue; + } + if (tv->getMemoryType() == MemoryType::Local && r_id->isThreadDim()) { + continue; + } last_alloc_dim = r_id; last_alloc_dim_pos = i - 1; break; @@ -416,9 +426,9 @@ class VectorizeValidator : public OptInDispatch { ", allocation domain: ", ir_utils::toString(tv->getMaybeAllocationDomain()), ", vectorized id: ", - validator.vectorized_id_, + validator.vectorized_id_->toString(), ", innermost id: ", - last_alloc_dim, + last_alloc_dim->toString(), ", contiguity: ", contiguity.has_value() ? (*contiguity ? 
"t" : "f") : "n"); } diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index afe6597fcac..f4917a5e663 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -687,6 +687,51 @@ void scheduleFusionInputsForEpilogue( } } +void scheduleSplitKSum( + TensorView* splitk_sum, + const int num_batch_dims, // TODO: this should not be needed + bool use_smem_epilogue) { + if (splitk_sum == nullptr) { + // This indicates no split-K was used + return; + } + + // Always use serial grid reduction for split-K sum + splitk_sum->definition()->as()->requestSerialGridReduction(); + + if (use_smem_epilogue) { + // Now that transforms are propagated backward to smem_epilogue, which is + // before splitk_sum, we can vectorize the inner-most non-trivial + // dimension of splitk_sum + // + // Note that the split-K reduction is the inner-most dimension. + Val* vec_ext = splitk_sum->axis(-2)->extent(); + NVF_ERROR(vec_ext->isConstInt()); + int64_t vec_ext_int = vec_ext->evaluate().as(); + splitk_sum->axis(-1)->parallelize(ParallelType::BIDz); + splitk_sum->axis(-3)->parallelize(ParallelType::TIDx); + if (vec_ext_int * dataTypeSize(splitk_sum->dtype()) > 16) { + // NOTE: We might encounter an illegal vectorization size if we are + // using Float for this reduction and Half for output. So here we first + // check whether the vectorize size is at most 16 bytes. If not, then we + // split into an unrolled loop that will do multiple vectorized + // reads/writes instead. Note that we reorder such that the axes are in + // order UR TIDx V. + splitk_sum->split( + -2, 16 / dataTypeSize(splitk_sum->dtype()), /*inner_split=*/true); + splitk_sum->axis(-3)->parallelize(ParallelType::Unroll); + splitk_sum->reorder({{-4, -3}}); + // In this case, we have [... iUR iTx rBz iS] + } + splitk_sum->reorder({{-2, -1}}); + } else { // no smem epilogue + // Reorder to place the split-K reduction next to innermost [... rBz iS] + splitk_sum->reorder({{-9, -2}}); + } + // Vectorize inner-most dimension [... (iUR iTx) rBz iV] + splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize); +} + } // namespace void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { @@ -797,10 +842,6 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // Setup accumulator register. auto mma_result = mma->out()->as(); - // Unswizzle mma result in shared memory - auto smem_epilogue = - params.use_smem_epilogue ? mma_result->cacheAfter() : mma_result; - // TODO: // Significant build out needed here // for more flexibility and data type support. @@ -867,6 +908,13 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // [... 
M,N,K] mma_utils::makeTile(mma_result, gemm_tile.cta_tile.toVector()); + // [..., Mo, No, Ko, Mi, Ni, Ki] + + // Unswizzle mma result in shared memory + // Note that if we are using split-K, we will set up this buffer after + // rfactoring the matmul, between the MmaOp and the ReductionOp, in order to + // take advantage of unswizzling during the grid reduction + TensorView* smem_epilogue = mma_result; // Swizzle block tiles: if (params.grid_swizzle_factor != 1) { @@ -889,15 +937,15 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { } } - // [..., Mo, No, Koo, Mi, Ni, Ki] + // [..., iMo, iNo, rKo, iMi, iNi, rKi] int num_splitk_dims = 0; TensorView* splitk_sum = nullptr; if (params.splitk_factor != 1) { - // Split Koo -> [Kf, Ko] + // Split Ko -> [rKf, rKg] mma_result->split(-4, params.splitk_factor, /*inner*/ false); - // After split [..., Mo, No, Kf, Ko, Mi, Ni, Ki] + // After split [..., iMo, iNo, rKf, rKg, iMi, iNi, rKi] // rFactor converts - // mma_result = mma(A, B, {/*Kf*/-5, /*Ko*/-4, /*Ki*/-1}); + // mma_result = mma(A, B, {/*Kf*/-5, /*Kg*/-4, /*Ki*/-1}); // to // intermediate = mma(A, B, {-4, -1}); // final_sum = sum(intermediate, {/*Kf*/-3}); @@ -906,11 +954,26 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { splitk_sum = mma_result; mma_result = splitk_sum->rFactor({-4, -1}); - splitk_sum->definition()->as()->requestSerialGridReduction(); - num_splitk_dims = 1; } + // At this point we have the following schedule: + // No split-K + // mma_result [..., iMo, iNo, rKo, iMi, iNi, rKi] + // Split-K + // mma_result [..., iMo, iNo, iKf, rKg, iMi, iNi, rKi] + // splitk_sum [..., iMo, iNo, rKf, iMi, iNi] + + if (params.use_smem_epilogue) { + // Note that for split-K + // splitk_sum = sum(mma_result) + // becomes + // smem_epilogue = set(mma_result) + // splitk_sum = sum(smem_epilogue) + smem_epilogue = mma_result->cacheAfter(); + // smem_epilogue = [..., iMo, iNo, iKf, iMi, iNi] + } + // Propagate tiling globally scheduler_utils::transformPropagateToAllFrom(mma_result, -1); @@ -922,19 +985,38 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { } // Schedule warp tile + // Incoming mma_result = [... iMo iNo (iKf) rKg iMi iNi rKi] mma_utils::scheduleWarpTileWithReduction(mma_result, gemm_tile); - // [..., Mo, No, (Kf,) Ko, Kw, Mwo, Nwo, Mwi, Nwi, Mi, Ni, Ki] + // After scheduling warp tile, the last three dimensions are split and + // rearranged: + // -3 -2 -1 + // [... M N K] + // maps to + // -8 -7 -6 -5 -4 -3 -2 -1 + // [... Kwo Mwo Nwo Mw Nw Mi Ni Ki] + // so now + // -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 + // mma_result = [... iMo iNo (iKf) rKg rKwo iMwo iNwo iMw iNw iMin iNin rKin] + // splitk_sum = [... iMo iNo rKf iMi iNi] // Propagate warp tile to main loop and epilog/output tvs scheduler_utils::BoundedDirectionalTransformPropagator::bothWays( mma_result, -1, {acw_smem, bcw_smem}, {smem_epilogue}); + // No (cross-CTA) split-K + // mma_result [..., iMo iNo rKo rKwo iMwo iNwo iMw iNw iMin iNin rKin] + // smem_epilogue (unscheduled, same as original or current mma_result) + // splitk_sum (nullptr) + // + // With split-K + // mma_result [... iMo iNo iKf rKg rKwo iMwo iNwo iMw iNw iMin iNin rKin] + // splitk_sum [... iMo iNo rKf iMi iNi] + // Schedule prolog: // TODO: this section needs more configurability. 
// ------------------------------------------------------------------ scheduleProlog(acw_smem, params); scheduleProlog(bcw_smem, params); - // [..., Mo, No, (Kf,) Ko, Kw, Mwo, Nwo, Mwi, Nwi, Mi, Ni, Ki] // Get the input to the mma op. mma = mma_result->definition()->as(); @@ -976,6 +1058,14 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { }; propagate_mma_input_schedule_to(acw_smem, bcw_smem); + // This does a split-reorder-merge swizzle of the last two M and N dimensions + // (and a possible final reduction dim). + // eg. [M64, N24, R] -> [WarpGroup128, N3, M2, N2, Ro, R4, R2] + // Before + // mma_result [... iMo iNo (iKf) rKg rKwo iMwo iNwo iMw iNw iMin iNin rKin] + // After + // mma_result [... iMo iNo (iKf) rKg rKwo iMwo iNwo iMw + // iNw iMino iNino iMin2 iNin2 rKino rKin4 rKin2] mma_result->applyMmaSwizzle(MmaOperand::Accumulator); // Set parallelization: @@ -1011,20 +1101,35 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { // - B: block // - T: thread // - S: serial. This will become a for loop in the generated kernel - // - iMMA: unconracted axis in an MMA tensor core operation. + // - iMMA: uncontracted axis in an MMA tensor core operation. // - rMMA: contract in an MMA tensor core operation. // - // with splitk: - // nbatch + 1 2 3 4 5 6 7 8 9 10 11 12 - // -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 - // [..., Mo, No, Kf, Ko, Kw, Mwo, Nwo, Mwi, Nwi, MNi1, MNi2, MNi3, Ki] - // (iS) iBx iBy rBz rS rS iTz iTy iS iS iMMA iTx iMMA rMMA + // With split-K: + // mma_result + // nbatch + 1 2 3 4 5 6 7 8 + // -15 -14 -13 -12 -11 -10 -9 -8 + // [... iMo iNo (iKf) rKg rKwo iMwo iNwo iMw iNw ... + // iBx iBy iBz rS rS iTz iTy iS iS + // 9 10 11 12 13 14 15 + // -7 -6 -5 -4 -3 -2 -1 + // ... iMino iNino iMin2 iNin2 rKino rKin4 rKin2] + // iTx iMMA iMMA iMMA rMMA rMMA rMMA + // smem_epilogue (unscheduled, same as original mma_result) + // splitk_sum (nullptr) // - // without splitk: - // nbatch + 1 2 3 4 5 6 7 8 9 10 11 - // -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1 - // [..., Mo, No, Ko, Kw, Mwo, Nwo, Mwi, Nwi, MNi1, MNi2, MNi3, Ki] - // (iBz) iBx iBy rS rS iTz iTy iS iS iMMA iTx iMMA rMMA + // Without split-K: + // mma_result + // nbatch + 1 2 3 4 5 6 7 8 + // -14 -13 -12 -11 -10 -9 -8 -7 + // [... iMo iNo rKg rKwo iMwo iNwo iMw iNw iMino + // (iBz) iBx iBy rS rS iTz iTy iS iS iTx + // 9 10 11 12 13 14 + // -6 -5 -4 -3 -2 -1 + // iNino iMin2 iNin2 rKino rKin4 rKin2] + // iMMA iMMA iMMA rMMA rMMA rMMA + // smem_epilogue (unscheduled, same as original mma_result) + // splitk_sum + // [... iMo iNo rKf iMi iNi] // When we have both batch dims and splitk, parallelize splitk only. // If we only have batch dim, parallelize the batch dim. @@ -1101,45 +1206,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { scheduleFusionInputsForEpilogue(roles_map, params.use_smem_epilogue); } - if (num_splitk_dims) { - // Here we reorder splitk_sum so that the grid reduction in the z dimension - // is placed last, ensuring that we can inline it with downstream tensors. - // - // Epilogue tensors: - // nbatch + 1 2 3 4 5 6 7 8 - // [... Mo No Mwo Nwo Mwi Nwi MNi1 MNi2 MNi3] - // (iS) iBx iBy iTz iTy iS iS iS iTx iV - // - // Before reordering, splitk_sum is similar to mma_result but with the - // reduction axes removed. The Kf dimension causes inlining between nbatch - // + 1 and nbatch + 2. - // nbatch + 1 2 3 4 5 6 7 8 9 - // [... 
Mo No Kf Mwo Nwo Mwi Nwi MNi1 MNi2 MNi3] - // (iS) iBx iBy rBz iTz iTy iS iS iS iTx iS - // - // splitk_sum (after the reordering below) - // nbatch + 1 2 3 4 5 6 7 8 9 - // [... Mo No Mwo Nwo Mwi Nwi MNi1 MNi2 MNi3 Kf] - // (iS) iBx iBy iTz iTy iS iS iS iTx iS rBz - // - // This reordering step lets us inline all but the last dim MNi3 (position - // nbatch + 7) which might be vectorized. - // - // NOTE: we need to do this reorder after the propagation above so that it - // doesn't get reset. - splitk_sum->reorder({ - {num_batch_dims + 2, num_batch_dims + 9}, - {num_batch_dims + 3, num_batch_dims + 2}, - {num_batch_dims + 4, num_batch_dims + 3}, - {num_batch_dims + 5, num_batch_dims + 4}, - {num_batch_dims + 6, num_batch_dims + 5}, - {num_batch_dims + 7, num_batch_dims + 6}, - {num_batch_dims + 8, num_batch_dims + 7}, - {num_batch_dims + 9, num_batch_dims + 8}, - }); - // Vectorize inner-most dimension - splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize); - } + scheduleSplitKSum(splitk_sum, num_batch_dims, params.use_smem_epilogue); // auto inline for all tensors except register tensors inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb})); @@ -1172,7 +1239,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { if (params.double_buffer_options.double_buffer_smem_read && params.double_buffer_options.double_buffer_smem_write) { - // rotate Ko loop + // rotate Kg loop scheduler_utils::rotateLoop( mma_result, num_batch_dims + 2 + num_splitk_dims, {acr, bcr}); } diff --git a/tests/cpp/test_gpu_tensorcore.cpp b/tests/cpp/test_gpu_tensorcore.cpp index 1f9cadfc765..8b2e0c01a1f 100644 --- a/tests/cpp/test_gpu_tensorcore.cpp +++ b/tests/cpp/test_gpu_tensorcore.cpp @@ -2811,54 +2811,76 @@ TEST_F(GPUTTensorCoreTest, FusionAmpereMatmulSplitK_CUDA) { int M = 504, N = 136, K = 8096; for (auto layout : kAllSupportedMmaLayout) { - Fusion fusion; - FusionGuard fg(&fusion); + for (int splitk_factor : {2}) { + for (int use_smem_epilogue : {false, true}) { + Fusion fusion; + FusionGuard fg(&fusion); + auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout); - auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout); + auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half); + auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half); - auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half); - auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); - fusion.addInput(tv0); - fusion.addInput(tv1); + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); - tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); - tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); - auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); + fusion.addOutput(tv2); - fusion.addOutput(tv2); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.splitk_factor = splitk_factor; + if (use_smem_epilogue) { + std::tie( + params.use_smem_epilogue, params.promote_prologue_smem_reuse) = + 
mma_utils::generateSharedMemoryEpilogueHeuristics( + gemm_tile, + 1, + {DataType::Half, DataType::Half, DataType::Float}, + true, + true, + true); + if (!params.use_smem_epilogue) { + std::cout + << "Skipping smem epilogue due to shared memory constraints on this device" + << std::endl; + continue; + } + params.promote_prologue_smem_reuse = true; + } - MatmulParams params; - params.mma_macro = MmaMacro::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.splitk_factor = 2; - scheduleMatmul(&fusion, params); + scheduleMatmul(&fusion, params); - auto inputs = matmulAtInput3DTuring(M, N, K, layout); + auto inputs = matmulAtInput3DTuring(M, N, K, layout); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + EXPECT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - // Relax tolerance for larger sum due to large K - NVF_CHECK(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + // Relax tolerance for larger sum due to large K + NVF_CHECK(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); - // Check that computed smem matches actually allocated smem - mma_utils::MmaDataTypes data_types = { - DataType::Half, DataType::Half, DataType::Float}; - int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( - params, data_types, true, true); - int64_t actual_smem = fe.lastLaunchParams().smem(); - EXPECT_EQ(estimated_smem, actual_smem); + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); + } + } } } @@ -2873,61 +2895,67 @@ TEST_F(GPUTTensorCoreTest, FusionAmpereMatmulSplitKBias_CUDA) { int M = 504, N = 136, K = 8096; for (auto layout : kAllSupportedMmaLayout) { - Fusion fusion; - FusionGuard fg(&fusion); + for (int splitk_factor : {2}) { + for (int use_smem_epilogue : {false, true}) { + Fusion fusion; + FusionGuard fg(&fusion); + auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout); - auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout); + auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half); + auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half); + auto tv2 = makeContigTensor(1, DataType::Half); - auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half); - auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half); - auto tv2 = makeContigTensor(1, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + auto tv3 = fusedMultiplySum(tv0, tv1, {-1}); + auto tv4 = broadcast(tv2, {false, true}); + auto tv5 = add(tv3, tv4); // bias - tv0 = 
canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); - tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); - auto tv3 = fusedMultiplySum(tv0, tv1, {-1}); - auto tv4 = broadcast(tv2, {false, true}); - auto tv5 = add(tv3, tv4); // bias + fusion.addOutput(tv5); - fusion.addOutput(tv5); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.splitk_factor = splitk_factor; + params.use_smem_epilogue = use_smem_epilogue; + params.promote_prologue_smem_reuse = true; - MatmulParams params; - params.mma_macro = MmaMacro::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.splitk_factor = 2; - scheduleMatmul(&fusion, params); + scheduleMatmul(&fusion, params); - auto [aten_a, aten_b] = matmulAtInput3DTuring(M, N, K, layout); - at::Tensor aten_bias = at::randn({M}, aten_a.options()); - std::vector inputs = {aten_a, aten_b, aten_bias}; + auto [aten_a, aten_b] = matmulAtInput3DTuring(M, N, K, layout); + at::Tensor aten_bias = at::randn({M}, aten_a.options()); + std::vector inputs = {aten_a, aten_b, aten_bias}; - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, inputs)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion(inputs); - auto tref = atBiasEpilogue( - atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout), - aten_bias); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, inputs)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion(inputs); + auto tref = atBiasEpilogue( + atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout), + aten_bias); - // Relax tolerance for larger sum due to large K - NVF_CHECK(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + // Relax tolerance for larger sum due to large K + NVF_CHECK(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); - // Check that computed smem matches actually allocated smem - mma_utils::MmaDataTypes data_types = { - DataType::Half, DataType::Half, DataType::Float}; - int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( - params, data_types, true, true); - int64_t actual_smem = fe.lastLaunchParams().smem(); - EXPECT_EQ(estimated_smem, actual_smem); + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, data_types, true, true); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); + } + } } } @@ -2942,55 +2970,68 @@ TEST_F(GPUTTensorCoreTest, FusionAmpereMatmulBatchSplitK_CUDA) { int B = 2, M = 504, N = 136, K = 2048; for (auto layout : kAllSupportedMmaLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(3, DataType::Half); - auto tv1 = makeContigTensor(3, DataType::Half); + for (int splitk_factor : {2}) { + for (int use_smem_epilogue : {false, true}) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(3, DataType::Half); + auto tv1 = makeContigTensor(3, 
DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); - tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); - tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); - auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); - fusion.addOutput(tv2); + fusion.addOutput(tv2); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - MatmulParams params; - params.mma_macro = MmaMacro::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.splitk_factor = 2; - scheduleMatmul(&fusion, params); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.splitk_factor = splitk_factor; + params.use_smem_epilogue = use_smem_epilogue; + params.promote_prologue_smem_reuse = true; - at::Tensor aten_a = - matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - at::Tensor aten_b = - matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + scheduleMatmul(&fusion, params); - std::vector inputs = {aten_a, aten_b}; + at::Tensor aten_a = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + at::Tensor aten_b = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, inputs)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion(inputs); - auto tref = atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout); + std::vector inputs = {aten_a, aten_b}; - // Relax tolerance for larger sum due to large K - EXPECT_TRUE(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, inputs)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion(inputs); + auto tref = + atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout); - // Check that computed smem matches actually allocated smem - mma_utils::MmaDataTypes data_types = { - DataType::Half, DataType::Half, DataType::Float}; - int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( - params, data_types, true, true); - int64_t actual_smem = fe.lastLaunchParams().smem(); - EXPECT_EQ(estimated_smem, actual_smem); + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, + data_types, + // NOTE: Batch split-K matmuls cannot currently re-use smem due to + // outer batch loop + /*smem_a_reuse_guaranteed=*/false, + /*smem_b_reuse_guaranteed=*/false); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); + } + } } } @@ -3005,62 +3046,74 @@ TEST_F(GPUTTensorCoreTest, FusionAmpereMatmulBatchSplitKBias_CUDA) { int B = 2, M = 504, N = 136, K = 2048; 
for (auto layout : kAllSupportedMmaLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(3, DataType::Half); - auto tv1 = makeContigTensor(3, DataType::Half); - auto tv2 = makeContigTensor(1, DataType::Half); + for (int splitk_factor : {2}) { + for (int use_smem_epilogue : {false, true}) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(3, DataType::Half); + auto tv1 = makeContigTensor(3, DataType::Half); + auto tv2 = makeContigTensor(1, DataType::Half); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); - tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); - tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); - auto tv3 = fusedMultiplySum(tv0, tv1, {-1}); - auto tv4 = broadcast(tv2, {true, false, true}); - auto tv5 = add(tv3, tv4); + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + auto tv3 = fusedMultiplySum(tv0, tv1, {-1}); + auto tv4 = broadcast(tv2, {true, false, true}); + auto tv5 = add(tv3, tv4); - fusion.addOutput(tv5); + fusion.addOutput(tv5); - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 32); - gemm_tile.warp_tile = GemmTile(64, 64, 32); - gemm_tile.instruction_tile = GemmTile(16, 8, 16); + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); - MatmulParams params; - params.mma_macro = MmaMacro::Ampere_16_8_16; - params.tile_sizes = gemm_tile; - params.splitk_factor = 2; - scheduleMatmul(&fusion, params); + MatmulParams params; + params.mma_macro = MmaMacro::Ampere_16_8_16; + params.tile_sizes = gemm_tile; + params.splitk_factor = splitk_factor; + params.use_smem_epilogue = use_smem_epilogue; + params.promote_prologue_smem_reuse = true; - at::Tensor aten_a = - matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); - at::Tensor aten_b = - matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); - at::Tensor aten_bias = at::randn({M}, aten_a.options()); + scheduleMatmul(&fusion, params); - std::vector inputs = {aten_a, aten_b, aten_bias}; + at::Tensor aten_a = + matmulAtInput2D(layout, TensorMatmulPos::A, at::kHalf, M, N, K, B); + at::Tensor aten_b = + matmulAtInput2D(layout, TensorMatmulPos::B, at::kHalf, M, N, K, B); + at::Tensor aten_bias = at::randn({M}, aten_a.options()); - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 7, 5, fe.compileFusion(&fusion, inputs)); - ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); - auto cg_outputs = fe.runFusion(inputs); - auto tref = atBiasEpilogue( - atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout), - aten_bias); + std::vector inputs = {aten_a, aten_b, aten_bias}; - // Relax tolerance for larger sum due to large K - EXPECT_TRUE(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 7, 5, fe.compileFusion(&fusion, inputs)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + auto cg_outputs = fe.runFusion(inputs); + auto tref = atBiasEpilogue( + atMatmul(aten_a.to(at::kFloat), aten_b.to(at::kFloat), layout), + aten_bias); - // Check that computed smem matches actually allocated smem - mma_utils::MmaDataTypes data_types = { - DataType::Half, DataType::Half, DataType::Float}; - int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( - 
params, data_types, true, true); - int64_t actual_smem = fe.lastLaunchParams().smem(); - EXPECT_EQ(estimated_smem, actual_smem); + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(tref, 1e-6 * K, 1e-6 * K)); + + // Check that computed smem matches actually allocated smem + mma_utils::MmaDataTypes data_types = { + DataType::Half, DataType::Half, DataType::Float}; + int64_t estimated_smem = mma_utils::computeExpectedSharedMemoryUsage( + params, + data_types, + // NOTE: Batch split-K matmuls cannot currently re-use smem due to + // outer batch loop + /*smem_a_reuse_guaranteed=*/false, + /*smem_b_reuse_guaranteed=*/false); + int64_t actual_smem = fe.lastLaunchParams().smem(); + EXPECT_EQ(estimated_smem, actual_smem); + } + } } }
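
Editorial note (not part of the patch): for readers unfamiliar with the serial grid reduction used by this PR, the sketch below illustrates the idea behind the `reduction::serialReductionStep` call seen in the generated kernel above. It is an assumption-laden simplification, not nvFuser's runtime implementation: the predicate arguments are omitted, and the names `serialReductionStepSketch`, `work_buf`, `is_first_block`, and `is_last_block` are illustrative only. The point is that blocks sharing an output tile take turns (ordered by `blockSerializeWait`/`blockSerializeRelease`), so the work buffer can be read and written with the same coalesced indexing as the final output store, with no atomics.

```c++
// Editorial sketch (not nvFuser source): one serial grid reduction step for
// a single element. Blocks arrive here one at a time, serialized by a
// semaphore, so the running partial sum can live in an ordinary global
// work buffer.
template <typename T, typename Func>
__device__ void serialReductionStepSketch(
    T& out,            // this block's result; complete only on the last block
    T in,              // this block's partial value (e.g. from the unswizzled smem tile)
    T init,            // identity of the reduction, e.g. 0.0f for a sum
    T* work_buf,       // global work buffer element shared by participating blocks
    Func reduction_op, // e.g. [](T& a, T b) { a = a + b; }
    bool is_first_block,
    bool is_last_block) {
  // The first block starts from the identity; later blocks resume from the
  // running sum left behind by the previous block.
  T acc = is_first_block ? init : *work_buf;
  reduction_op(acc, in);
  if (!is_last_block) {
    // Leave the running sum for the next block in line.
    *work_buf = acc;
  }
  // On the last block, acc is the fully reduced value that feeds the
  // epilogue (e.g. bias add) and the coalesced output store.
  out = acc;
}
```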