Unswizzle before grid reduction in split-K (#1534)
Serial grid reductions are used in split-K matmuls as of #1510. This
means we load and store elements in the reduction tensor according to
the indexing of the work buffer. This is unlike ordinary grid reductions
that use `gridReduce`, which reduces individual elements using a scheme
that ensures coalescing by indexing into the work buffer based on
`threadIdx` and `blockIdx`. Currently these split-K accesses are
inefficient due to this lack of coalescing.
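
To make the difference concrete, here is a toy sketch (hypothetical CUDA,
not nvFuser-generated code; the kernel and parameter names are made up): a
`gridReduce`-style work buffer is addressed directly by
`threadIdx`/`blockIdx`, whereas the serial reduction addresses it with the
reduction tensor's own (swizzled, register-tile) indexing, which can leave
neighboring threads strided apart in memory.
```c++
// Hypothetical sketch, not nvFuser library or generated code. `work` stands
// in for the global-memory work buffer, `partial` for a per-CTA partial sum.
#include <cuda_runtime.h>

// gridReduce-style addressing: consecutive threads in a warp touch
// consecutive addresses, so the memory transactions are coalesced.
__global__ void coalescedWorkBuffer(float* work, const float* partial) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  work[idx] += partial[idx];
}

// Serial-reduction-style addressing before this PR: the work buffer is
// indexed with the reduction tensor's (swizzled, register-tile) layout, so
// consecutive threads can be strided apart and the accesses are uncoalesced.
__global__ void stridedWorkBuffer(float* work, const float* partial, int stride) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * stride;
  work[idx] += partial[idx];
}
```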

We already ensure coalesced output stores in matmuls when possible by
using smem for the epilogue (#387): a shared memory buffer is used to
communicate elements between threads so that the resulting tensor has a
proper global access pattern when it is written out to global memory as
a tile of the output. Before this PR, if we used split-K with
`use_smem_epilogue = true`, the store to global memory was coalesced
but the accesses during the split-K reduction were not. This PR
modifies the scheduling so that in those cases the smem epilogue tensor
is placed before the split-K sum, meaning unswizzling happens before
the reduction completes. As a result, the reduction accesses are
coalesced.
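
Schematically, the change just reorders the epilogue stages. The sketch
below is a hypothetical host-side stand-in for that ordering (the stage
functions are placeholders, not nvFuser APIs):
```c++
// Minimal sketch treating each epilogue stage as a black box; the function
// names are illustrative placeholders only.
#include <utility>
#include <vector>

using Tile = std::vector<float>;

Tile unswizzleThroughSmem(Tile regs) { return regs; } // smem epilogue: fix the layout
Tile serialGridReduce(Tile t) { return t; }           // split-K sum via the gmem work buffer
void storeOutputTile(const Tile&) {}                  // final (coalesced) global store

// Before this PR: reduce on the swizzled register layout, then unswizzle,
// so the work-buffer traffic is uncoalesced.
void epilogueBefore(Tile mmaAccumulators) {
  storeOutputTile(unswizzleThroughSmem(serialGridReduce(std::move(mmaAccumulators))));
}

// After this PR: unswizzle first, so the serial reduction's work-buffer
// accesses already follow the coalesced output layout.
void epilogueAfter(Tile mmaAccumulators) {
  storeOutputTile(serialGridReduce(unswizzleThroughSmem(std::move(mmaAccumulators))));
}
```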

This is a generated kernel from
`NVFuserTest.FusionAmpereMatmulSplitKBias_CUDA`:
```c++
// ... (main loop) ...
     #pragma unroll
      for(nvfuser_index_t i59 = 0; i59 < 4LL; ++i59) {
        nvfuser_index_t i104;
        i104 = 8LL * i59;
        nvfuser_index_t i105;
        i105 = 32LL * i59;
        #pragma unroll
        for(nvfuser_index_t i61 = 0; i61 < 8LL; ++i61) {
          nvfuser_index_t i106;
          i106 = 4LL * i61;
          asm(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            :"=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
            :"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[1]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[2]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[3]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
          );
        }
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    __syncthreads();
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i107 = 0; i107 < 4LL; ++i107) {
    nvfuser_index_t i108;
    i108 = 32LL * i107;
    nvfuser_index_t i109;
    i109 = i38 + (2048LL * i107);
    #pragma unroll
    for(nvfuser_index_t i110 = 0; i110 < 8LL; ++i110) {
      nvfuser_index_t i111;
      i111 = i108 + (4LL * i110);
      nvfuser_index_t i112;
      i112 = i11 + i110;
      nvfuser_index_t i113;
      i113 = (i109 + (32LL * (i112 / 4LL))) + (8LL * (i39 ^ (i112 % 4LL)));
      #pragma unroll
      for(nvfuser_index_t i114 = 0; i114 < 2LL; ++i114) {
        loadGeneric<float, 2>( &T17[(i113 + (1024LL * i114))],  &T16[(i111 + (2LL * i114))]);
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Allocate global tensor T19
  grid_sync::blockSerializeWait<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i115 = 0; i115 < 32LL; ++i115) {
    nvfuser_index_t i116;
    i116 = i115 + nvfuser_zero;
    nvfuser_index_t i117;
    i117 = i44 + (i45 * i116);
    nvfuser_index_t i118;
    i118 = i47 + (4LL * i115);
    bool b119;
    b119 = i55 < (-(4LL * i116));
    bool b120;
    b120 = b54 && b119;
    Array<float, 4LL, 4> T6;
    T6.set(float(0.000000000e+00f));
    // Allocate global tensor T20
    reduction::serialReductionStep</*vec_size=*/4>(
      &T6[0LL],
      &T17[(i42 + (512LL * i115))],
      0.000000000e+00f,
      &T20[i117],
      [](float &a, float b) { a = a + b; },
      index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
      index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
      b120,
      b120);
    Array<float, 4LL, 4> T10;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4LL; ++i121) {
      __half T18[1LL];
      T18[0LL] = 0LL;
      if (b119) {
        T18[0LL]
           = T2[(i118 + ((i48 + (i121 + nvfuser_zero)) / 128LL))];
      }
      __half T7[1LL];
      T7[0LL]
         = T18[0LL];
      float T8[1LL];
      T8[0LL]
         = __half2float(T7[0LL]);
      T10[i121]
        = T6[i121]
        + T8[0LL];
    }
    if ((b56 && b119)) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T9[i117], &T10[0LL]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
}
```
Note that the `i115` loop will be smaller once we have #1528, at which
point it will more clearly show the reduction followed by the loop for
the predicated bias epilogue.

(Diff should be viewed hiding whitespace changes, as many changes are to
indentation.)
jacobhinkle authored Mar 22, 2024
1 parent f05f9bc commit 7525cc0
Showing 4 changed files with 402 additions and 240 deletions.
50 changes: 41 additions & 9 deletions benchmarks/cpp/matmul.cpp
@@ -16,6 +16,7 @@
#include <scheduler/all_schedulers.h>
#include <scheduler/matmul.h>
#include <scheduler/matmul_heuristic.h>
#include <scheduler/mma_utils.h>
#include <utils.h>

#include <benchmark/benchmark.h>
@@ -145,6 +146,9 @@ static void SingleMatmulBase(
int64_t n = benchmark_state.range(1);
int64_t k = benchmark_state.range(2);

// inputs
at::manual_seed(0);

// Tensor inputs
auto inputs = matmulAtInput2D(m, n, k, layout);
auto expected_output = atMatmul(
@@ -159,9 +163,6 @@

preseg_passes::OptimizationPass<preseg_passes::PreSegmenter>::runPass(fusion);

// inputs
at::manual_seed(0);

KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(
{inputs.first, inputs.second});

Expand Down Expand Up @@ -250,6 +251,14 @@ MatmulParams getMatmulParams(
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.smem_double_buffer_stage = stage_number;
params.splitk_factor = splitk_factor;
std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) =
mma_utils::generateSharedMemoryEpilogueHeuristics(
gemm_tile,
stage_number,
{DataType::Half, DataType::Half, DataType::Float},
/*smem_a_reuse_guaranteed=*/true,
/*smem_b_reuse_guaranteed=*/true,
/*ignore_occupancy_drop=*/true);

return params;
}
@@ -362,7 +371,8 @@ static void NvFuserScheduler_Matmul(
benchmark::State& benchmark_state,
MmaLayout layout,
int splitk_factor = 1,
bool partitionedk = false) {
bool partitionedk = false,
bool use_smem_epilogue = false) {
int num_warps = benchmark_state.range(3);
int number_of_stage = benchmark_state.range(4);

@@ -376,6 +386,15 @@

auto params = getMatmulParams(
cta_tile, number_of_stage, layout, partitionedk ? 1 : splitk_factor);
if (use_smem_epilogue) {
if (!params.use_smem_epilogue) {
benchmark_state.SkipWithError(
"Insufficient shared mem for smem epilogue");
}
} else {
params.use_smem_epilogue = false;
params.promote_prologue_smem_reuse = false;
}

NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);
@@ -630,13 +649,23 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) {
// Use this for manual splitk.
static void MatmulShapeWarpStageSpecificSplitK(
benchmark::internal::Benchmark* b) {
b->ArgNames({"M", "N", "K", "warps", "stages", "splitk_factor"});
b->ArgNames(
{"M", "N", "K", "warps", "stages", "splitk_factor", "smem_epilogue"});
for (long int num_warps : NumWarps) {
for (long int num_stages : NumStages) {
for (auto [m, n, k] :
std::vector<std::tuple<int, int, int>>(SplitKSpecificShapes)) {
for (auto splitk_factor : {2, 3, 4, 5, 6}) {
b->Args({m, n, k, num_warps, num_stages, splitk_factor});
for (bool use_smem_epilogue : {false, true}) {
b->Args(
{m,
n,
k,
num_warps,
num_stages,
splitk_factor,
use_smem_epilogue});
}
}
}
}
@@ -727,8 +756,13 @@ static void NvFuserScheduler_Matmul_Manual(
benchmark::State& benchmark_state,
MmaLayout layout) {
int splitk_factor = benchmark_state.range(5);
bool use_smem_epilogue = benchmark_state.range(6);
NvFuserScheduler_Matmul(
benchmark_state, layout, splitk_factor, /*partitionedk=*/false);
benchmark_state,
layout,
splitk_factor,
/*partitionedk=*/false,
use_smem_epilogue);
}

#define SpecificSplitKBenchmark(layout) \
@@ -742,9 +776,7 @@

ForAllLayouts(EagerModeBenchmark);
ForAllLayouts(NvfuserMatmulBenchmark);
ForAllLayouts(AutoSplitKBenchmark);
ForAllLayouts(SpecificSplitKBenchmark);
ForAllLayouts(AutoPartitionedKBenchmark);

// Note: SplitK Reduction benchmarks are parametrized only by M, N. The splitk
// factor is deduced automatically from N
14 changes: 12 additions & 2 deletions csrc/device_lower/validation.cpp
@@ -388,6 +388,16 @@ class VectorizeValidator : public OptInDispatch {
if (r_id->isReduction() || r_id->isBroadcast()) {
continue;
}
if ((tv->getMemoryType() == MemoryType::Shared ||
tv->getMemoryType() == MemoryType::Local) &&
r_id->isBlockDim()) {
// Inner-most parallelized dimensions don't count in allocation of
// shared and local tensors.
continue;
}
if (tv->getMemoryType() == MemoryType::Local && r_id->isThreadDim()) {
continue;
}
last_alloc_dim = r_id;
last_alloc_dim_pos = i - 1;
break;
@@ -416,9 +426,9 @@
", allocation domain: ",
ir_utils::toString(tv->getMaybeAllocationDomain()),
", vectorized id: ",
validator.vectorized_id_,
validator.vectorized_id_->toString(),
", innermost id: ",
last_alloc_dim,
last_alloc_dim->toString(),
", contiguity: ",
contiguity.has_value() ? (*contiguity ? "t" : "f") : "n");
}
