
Add epilogue to store MMA results in shared memory before write to global memory #387

Merged: 36 commits merged into main from llu/matmul_epilogue on Jul 17, 2023

Conversation

@liqiangxl (Collaborator) commented May 22, 2023:

  • Add an epilogue that stages MMA results in shared memory before writing them to global memory (MMA results --> shared memory --> global memory)
  • Remove bank conflicts
  • Reuse shared memory

Shared memory reuse is left to another PR; a sketch of the intended data flow follows.
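For readers unfamiliar with the motivation, here is a minimal illustrative CUDA sketch (not nvfuser-generated code) of the register --> shared memory --> global memory staging that the epilogue sets up. The tile sizes, CTA size, and the per-thread register layout are made-up placeholders; only the staging pattern itself reflects what this PR schedules.

```c++
#include <cuda_runtime.h>

// Illustrative sketch only: MMA accumulators live in registers in a scattered,
// per-thread layout; staging them through a shared-memory tile lets the block
// re-read the tile so that consecutive threads store consecutive addresses.
constexpr int TILE_M = 64;
constexpr int TILE_N = 64;
constexpr int THREADS = 256;                                // assumed CTA size
constexpr int VALS_PER_THREAD = TILE_M * TILE_N / THREADS;  // 16

__global__ void epilogue_sketch(float* __restrict__ out, int ldc) {
  __shared__ float c_smem[TILE_M][TILE_N];
  const int tid = threadIdx.x;

  // 1) "MMA results": each thread owns a chunk of accumulators in registers.
  //    Writing these straight to global memory would make neighbouring threads
  //    hit addresses VALS_PER_THREAD elements apart, i.e. uncoalesced stores.
  float acc[VALS_PER_THREAD];
  for (int i = 0; i < VALS_PER_THREAD; ++i) {
    acc[i] = static_cast<float>(tid * VALS_PER_THREAD + i);
  }

  // 2) Registers -> shared memory, keeping the register-ownership layout
  //    (the real scheduler also swizzles this store to avoid bank conflicts).
  for (int i = 0; i < VALS_PER_THREAD; ++i) {
    const int idx = tid * VALS_PER_THREAD + i;
    c_smem[idx / TILE_N][idx % TILE_N] = acc[i];
  }
  __syncthreads();

  // 3) Shared memory -> global memory with a remapped index so consecutive
  //    threads write consecutive addresses (coalesced stores).
  for (int i = 0; i < VALS_PER_THREAD; ++i) {
    const int idx = i * THREADS + tid;
    const int row = idx / TILE_N;
    const int col = idx % TILE_N;
    out[(blockIdx.y * TILE_M + row) * ldc + blockIdx.x * TILE_N + col] =
        c_smem[row][col];
  }
}
```

In the actual scheduler this staging is expressed with cacheBefore and a swizzled shared-memory tensor, as discussed later in this thread.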

@liqiangxl force-pushed the llu/matmul_epilogue branch 2 times, most recently from 6b6402c to a8639d4 on May 23, 2023 18:12
@liqiangxl force-pushed the llu/matmul_epilogue branch 2 times, most recently from e0f91bf to b73f307 on June 6, 2023 22:05
@liqiangxl force-pushed the llu/matmul_epilogue branch from b73f307 to 9fb5a54 on June 8, 2023 19:59
@liqiangxl requested a review from @zasdfgbnm on June 8, 2023 21:19
@liqiangxl marked this pull request as ready for review on June 8, 2023 21:19
@liqiangxl requested review from @zasdfgbnm and @drzejan2 on June 20, 2023 18:38
@zasdfgbnm (Collaborator) commented:
Could you add a new variant of MatmulSASSTest.AmpereModifiers_CUDA for the case where there is an epilogue unswizzle?

@zasdfgbnm (Collaborator) commented:
Also, could you:

  1. Set has_smem_epilogue to true and try our cpp benchmarks to see the results? (Do they succeed? Are they faster?)
  2. Post the generated kernel here for review?

@mmigdal-nv self-requested a review on June 23, 2023 08:20
@liqiangxl force-pushed the llu/matmul_epilogue branch 2 times, most recently from 81b0761 to 30e0791 on June 23, 2023 20:44
@liqiangxl (Collaborator, Author) commented:
Update:
(1) For some cases with large CTA tiles, there is not enough shared memory; has_smem_epilogue will be set to false for those.
(2) All other test and benchmark cases pass.
(3) There are many regressions in the benchmarks; I'm working on figuring out why performance dropped.

@liqiangxl (Collaborator, Author) commented:
The main reason for the performance regression is that the increased shared memory usage reduces the number of CTAs per SM from 2 to 1.
[screenshot: ncu results, left with the shared memory epilogue, right without]

I'll add shared memory reuse and then test performance again.
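For context, a back-of-the-envelope sketch of that occupancy drop, using the same variable names as the heuristic snippet quoted later in this thread; all byte counts and the per-SM limit below are made-up placeholders, not the values from the ncu run.

```c++
#include <cstddef>
#include <cstdio>

int main() {
  // Placeholder numbers chosen only to reproduce the 2 -> 1 CTAs/SM drop.
  const size_t device_smem_limit = 160 * 1024;  // assumed usable smem per SM
  const size_t smem_a = 32 * 1024;              // operand A tiles (double buffered)
  const size_t smem_b = 32 * 1024;              // operand B tiles (double buffered)
  const size_t smem_c = 32 * 1024;              // epilogue output tile added by this PR

  const size_t ctas_without_epilogue = device_smem_limit / (smem_a + smem_b);
  const size_t ctas_with_epilogue = device_smem_limit / (smem_a + smem_b + smem_c);

  // 160 KiB / 64 KiB = 2 CTAs/SM without the epilogue tile,
  // 160 KiB / 96 KiB = 1 CTA/SM with it: the occupancy regression above.
  std::printf("CTAs/SM: without=%zu with=%zu\n",
              ctas_without_epilogue, ctas_with_epilogue);
  return 0;
}
```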

@zasdfgbnm (Collaborator) commented:
It's totally up to you to decide how to proceed, but I would suggest polishing this PR first and merging it. As long as we are not enabling the smem epilogue by default, it should be fine. The schedule here makes sense; I just have some general software-engineering suggestions on how to generate this schedule to make it more maintainable in the long term. I make this suggestion because there will likely be other changes to the schedule, like split-K, that will surely conflict with this PR; merging this PR early would reduce the effort spent resolving conflicts.

@liqiangxl (Collaborator, Author) commented:
Makes sense. I'll tidy up this PR first.

@liqiangxl force-pushed the llu/matmul_epilogue branch from 37f591e to fad4ad8 on June 28, 2023 00:32
device_smem_limit / (smem_a + smem_b + smem_c), (size_t)blocks_per_sm);
return blocks_per_sm_with_smem_epilogue ==
blocks_per_sm_without_smem_epilogue ||
blocks_per_sm_with_smem_epilogue > 0;
@liqiangxl (Collaborator, Author) commented:
`|| blocks_per_sm_with_smem_epilogue > 0;` is a temporary condition that allows testing as many cases as possible without considering performance. It will be deleted before merge.

@liqiangxl (Collaborator, Author) commented:
!build

dataTypeSize(DataType::Float);
const size_t smem_size = smem_a + smem_b + smem_c;

return smem_size <= device_smem_limit;
A collaborator commented:
Thank you @liqiangxl for verifying both smem-based and register-based occupancy. The updated implementation looks good.

@@ -525,6 +656,10 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
// Mma object is valid only because cacheBefore has been done on
// TV which is not output of MmaOp, as there is an epilogue
auto mma_result = has_epilogue ? mma->out()->as<TensorView>() : cc;
// epilogue shared memory tensor if use shared memory epilogue
// mma_result -> cc (if has_epilogue) -> c_smem -> c
auto c_smem = params.has_smem_epilogue ? c->cacheBefore() : c;
A collaborator commented:
A quick question: is the follow-up task for this PR to enable shared memory reuse?

The tricky part is that we calculate smem usage with the output type, so for an HSH kernel the calculation is done for the fp16 type; maybe this needs to be changed.

We should discuss this at the next matmul meeting.

@liqiangxl (Collaborator, Author) commented:
> A quick question: is the follow-up task for this PR to enable shared memory reuse?

Yes.

// auto inline for all tensors except register tensors and output tensor
inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb, d}));
// auto inline for all tensors except register tensors
inlineMost(ir_utils::allTvsExcept(fusion, {acr, bcr, ab, bb}));
A collaborator commented:
Should this be a conditional modification, based on params.has_smem_epilogue?

Something like this:

  // auto inline for all tensors except register tensors
  std::vector<TensorView*> inlining_exceptions = {acr, bcr, ab, bb};
  if (!params.has_smem_epilogue) {
    // with no smem epilogue, we should also exclude the output tensor
    inlining_exceptions.push_back(d);
  }
  inlineMost(ir_utils::allTvsExcept(fusion, inlining_exceptions));

@liqiangxl (Collaborator, Author) commented Jul 12, 2023:
I noticed that the special treatment of the output tensor is not needed; this change was an attempt to make the code cleaner. I'll revert it and leave the cleanup to another PR.

// -5 -4 -3 -2 -1
// [matrix id, repeat, pattern, matrix id, matrix]
int64_t swizzle_period = ldmatrix_rows / repeated_pattern_size;
// [matrix id, repeat, pattern, matrix id, matrix, ***** skip ***** ]
TORCH_INTERNAL_ASSERT(
A collaborator commented:
Due to the above:

    int64_t swizzle_period =
        std::gcd(n_rows / repeated_pattern_size, tile_size_y / n_cols);

this check is now redundant and should be removed.

@liqiangxl (Collaborator, Author) commented:
removed.

Comment on lines 386 to 388
// update repeated_pattern_size to ensure n_rows / repeated_pattern_size
// equals to swizzle_period, this is required by 2D swizzle.
repeated_pattern_size = n_rows / swizzle_period;
A collaborator commented:
With the gcd above, swizzle_period is either equal to or smaller than n_rows / repeated_pattern_size. I don't think we should change repeated_pattern_size, because by definition, for i = 0...repeated_pattern_size, different i values never conflict. Instead, we should check that, if swizzle_period < n_rows / repeated_pattern_size, we split n_rows as 3D [outer, swizzle_period, repeated_pattern_size] and swizzle the swizzle_period dimension with the matrix id. In the future, we might want to just merge tile_x/n_rows with tile_y/n_cols and use that for the swizzle, but let's leave that to a future PR.

@liqiangxl (Collaborator, Author) commented:
Added the following code to split n_rows into 3D:

    // further split n_rows to [outer, swizzle_period, repeated_pattern_size]
    if (swizzle_period < n_rows / repeated_pattern_size) {
      const int repeat_axis = repeated_pattern_size > 1 ? -4 - skip : -3 - skip;
      shared_mem_tv->split(repeat_axis, swizzle_period);
      //     -6      -5             -4              -3       -2         -1
      // [matrix id, nrows_outer, swizzle_period, pattern, matrix id, matrix,
      // ***** skip ***** ]
    }

@liqiangxl (Collaborator, Author) commented:
With this change, bank conflicts can be removed in cases where they previously couldn't be fully removed.

// The first para in std::gcd is from the above derivation.
// The second para is from the fact that we need to split tile_size_y by
// n_cols and then by swizzle_period. If swizzle_period is smaller than the
// first para, then the bank conflict can't be fully removed.
A collaborator commented:
> then the bank conflict can't be fully removed.

Can you open an issue for that? We need to look deeper into this.

@liqiangxl (Collaborator, Author) commented:
That statement is no longer valid; I have corrected the code comments.

Comment on lines 385 to 416
int64_t swizzle_period =
std::gcd(n_rows / repeated_pattern_size, tile_size_y / n_cols);

// -2 -1
// [row, col]
// [row, col, ***** skip ***** ]
TORCH_INTERNAL_ASSERT(
tile_size_x % ldmatrix_rows == 0, "Partial matrices not supported");
shared_mem_tv->split(-2, ldmatrix_rows);
tile_size_x % n_rows == 0, "Partial matrices not supported");
shared_mem_tv->split(-2 - skip, n_rows);
TORCH_INTERNAL_ASSERT(
tile_size_y % ldmatrix_cols == 0, "Partial matrices not supported");
shared_mem_tv->split(-1, ldmatrix_cols);
tile_size_y % n_cols == 0, "Partial matrices not supported");
shared_mem_tv->split(-1 - skip, n_cols);
// -4 -3 -2 -1
// [matrix id, matrix, matrix id, matrix]
// [matrix id, matrix, matrix id, matrix, ***** skip ***** ]
TORCH_INTERNAL_ASSERT(
ldmatrix_rows % repeated_pattern_size == 0,
"ldmatrix_rows is assumed to be a multiple of repeated_pattern_size");
shared_mem_tv->split(-3, repeated_pattern_size);
n_rows % repeated_pattern_size == 0,
"n_rows is assumed to be a multiple of repeated_pattern_size");
if (repeated_pattern_size > 1) {
shared_mem_tv->split(-3 - skip, repeated_pattern_size);
}
// -5 -4 -3 -2 -1
// [matrix id, repeat, pattern, matrix id, matrix]
int64_t swizzle_period = ldmatrix_rows / repeated_pattern_size;
TORCH_INTERNAL_ASSERT(
tile_size_y % (swizzle_period * ldmatrix_cols) == 0,
"need aperiodic swizzle config for tile size ",
tile_size_x,
"x",
tile_size_y,
"with units ",
ldmatrix_rows,
"x",
ldmatrix_cols);
shared_mem_tv->split(-2, swizzle_period);
// -6 -5 -4 -3 -2 -1
// [matrix id, repeat, pattern, matrix id outer, pattern id, matrix]
// swizzle repeat with pattern id to make repeat no longer repeat
if (isPowOf2(swizzle_period)) {
shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2);
} else {
shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2);
// [matrix id, repeat, pattern, matrix id, matrix, ***** skip ***** ]

// further split n_rows to [outer, swizzle_period, repeated_pattern_size]
if (swizzle_period < n_rows / repeated_pattern_size) {
const int repeat_axis = repeated_pattern_size > 1 ? -4 - skip : -3 - skip;
shared_mem_tv->split(repeat_axis, swizzle_period);
// -6 -5 -4 -3 -2 -1
// [matrix id, nrows_outer, swizzle_period, pattern, matrix id, matrix,
// ***** skip ***** ]
}

shared_mem_tv->split(-2 - skip, swizzle_period);
A collaborator commented:
This part should be simplified with #588

zasdfgbnm added a commit that referenced this pull request Jul 13, 2023
In #387, we have to do a
```C++
    int64_t swizzle_period =
        std::gcd(n_rows / repeated_pattern_size, tile_size_y / n_cols);
```
in order to make our swizzling algorithm work for the epilogue. This looks
more like an empirical hack whose only goal is to create a square block.
Although it empirically worked, I struggled to find a first-principles
explanation for this approach. So I read through my original PR #155
multiple times and thought things through carefully. But the more I read
and think, the more I feel that the original implementation in #155 does
not make sense. The problem is that #155 tries to interleave the entire
`ldmatrix_rows / repeated_pattern_size` with an equal-size split on the
tile y dimension. This is overkill: we just need to evenly distribute rows
across different megabanks, and as long as we do so, the number of rows can
be arbitrarily large and we can still be bank-conflict free. So we should
be swizzling on a `(g, g)` block instead of a (potentially much larger)
`(ldmatrix_rows / repeated_pattern_size, ldmatrix_rows / repeated_pattern_size)` block.
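A small standalone sketch of the swizzle-period formula quoted at the top of this commit message; the tile and ldmatrix numbers below are illustrative, not taken from an actual schedule.

```c++
#include <cstdio>
#include <numeric>

int main() {
  // Illustrative values: rows/columns covered by one ldmatrix group, the number
  // of rows sharing a megabank pattern, and the smem tile's inner extent.
  const long n_rows = 16;
  const long n_cols = 8;
  const long repeated_pattern_size = 2;
  const long tile_size_y = 64;

  // swizzle_period is the size g of the (g, g) block the swizzle acts on:
  // the gcd keeps it no larger than either the row groups or the column groups.
  const long swizzle_period =
      std::gcd(n_rows / repeated_pattern_size, tile_size_y / n_cols);
  std::printf("swizzle_period = %ld\n", swizzle_period);  // gcd(8, 8) = 8
  return 0;
}
```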
@liqiangxl (Collaborator, Author) commented:

!build

@liqiangxl (Collaborator, Author) commented:

!build

@zasdfgbnm (Collaborator) left a review comment:
Mostly looks good to me; I left some minor comments.

// specific case. We should remove this special handling when we fix our CA
// mapping.
if (shared_mem_tv->getMemoryType() == MemoryType::Shared) {
int swizzle_axis0 = repeated_pattern_size > 1 ? -5 : -4;
A collaborator commented:
nit: Just use axis_of_gigarow_id

@liqiangxl (Collaborator, Author) commented:
renamed.

MatmulParams params;
params.mma_macro = MmaOptions::MacroType::Ampere_16_8_16;
params.tile_sizes = gemm_tile;
params.has_smem_epilogue = false;
A collaborator commented:
Should this be true?

Comment on lines 418 to 421
// Disable shared memory epilogue before shared memory reuse is implemented.
// Otherwise, there will be performance regression due to reduced occupancy
// caused by extra shared memory usage.
constexpr bool allow_smem_epilogue = true;
A collaborator commented:
Is this still true with the new hasEnoughSharedMemoryForEpilogue that takes register occupancy into consideration?

@liqiangxl (Collaborator, Author) commented:
In the FusionAmpereMatmulSmemEpilogue_CUDA case, the settings of cta_tile (64, 128, 32), warp_tile (32, 32, 32), and smem_double_buffer_stage (2) are purposefully chosen to produce a constant occupancy of 25% whether has_smem_epilogue is true or false. This lets me evaluate the influence of has_smem_epilogue alone. There is no noticeable change in performance for TT, NT, TN (kernels 1-3). For NN (kernel 4), there is a slight decrease. I checked the global memory write indices: the access is coalesced for has_smem_epilogue = true and strided for has_smem_epilogue = false.

has_smem_epilogue = false
smem_false_64_128_32.log:kernel1 run in 1.19603 ms, achieved: 112.219 GB/s
smem_false_64_128_32.log:kernel2 run in 1.18682 ms, achieved: 113.091 GB/s
smem_false_64_128_32.log:kernel3 run in 1.00045 ms, achieved: 134.158 GB/s
smem_false_64_128_32.log:kernel4 run in 1.16224 ms, achieved: 115.482 GB/s

has_smem_epilogue = true
smem_true_64_128_32.log:kernel1 run in 1.19296 ms, achieved: 112.508 GB/s
smem_true_64_128_32.log:kernel2 run in 1.19091 ms, achieved: 112.702 GB/s
smem_true_64_128_32.log:kernel3 run in 0.995328 ms, achieved: 134.848 GB/s
smem_true_64_128_32.log:kernel4 run in 1.18067 ms, achieved: 113.679 GB/s

Global Memory write index for MatmulLayout::NN with has_smem_epilogue = false:

smem_false4_64_128_32_write_index
tidyz=00 tidx=0 from=0 to=0
tidyz=00 tidx=1 from=0 to=2
tidyz=00 tidx=2 from=0 to=4
tidyz=00 tidx=3 from=0 to=6
tidyz=00 tidx=4 from=0 to=4096
tidyz=00 tidx=5 from=0 to=4098
tidyz=00 tidx=6 from=0 to=4100
tidyz=00 tidx=7 from=0 to=4102
tidyz=00 tidx=8 from=0 to=8192
tidyz=00 tidx=9 from=0 to=8194
tidyz=00 tidx=10 from=0 to=8196
tidyz=00 tidx=11 from=0 to=8198
tidyz=00 tidx=12 from=0 to=12288
tidyz=00 tidx=13 from=0 to=12290
tidyz=00 tidx=14 from=0 to=12292
tidyz=00 tidx=15 from=0 to=12294
tidyz=00 tidx=16 from=0 to=16384
tidyz=00 tidx=17 from=0 to=16386
tidyz=00 tidx=18 from=0 to=16388
tidyz=00 tidx=19 from=0 to=16390
tidyz=00 tidx=20 from=0 to=20480
tidyz=00 tidx=21 from=0 to=20482
tidyz=00 tidx=22 from=0 to=20484
tidyz=00 tidx=23 from=0 to=20486
tidyz=00 tidx=24 from=0 to=24576
tidyz=00 tidx=25 from=0 to=24578
tidyz=00 tidx=26 from=0 to=24580
tidyz=00 tidx=27 from=0 to=24582
tidyz=00 tidx=28 from=0 to=28672
tidyz=00 tidx=29 from=0 to=28674
tidyz=00 tidx=30 from=0 to=28676
tidyz=00 tidx=31 from=0 to=28678

Global Memory write index for MatmulLayout::NN with has_smem_epilogue = true:

smem_true4_64_128_32_write_index
tidyz=00 tidx=0 from=0 to=0
tidyz=00 tidx=1 from=4 to=4
tidyz=00 tidx=2 from=8 to=8
tidyz=00 tidx=3 from=12 to=12
tidyz=00 tidx=4 from=16 to=16
tidyz=00 tidx=5 from=20 to=20
tidyz=00 tidx=6 from=24 to=24
tidyz=00 tidx=7 from=28 to=28
tidyz=00 tidx=8 from=32 to=32
tidyz=00 tidx=9 from=36 to=36
tidyz=00 tidx=10 from=40 to=40
tidyz=00 tidx=11 from=44 to=44
tidyz=00 tidx=12 from=48 to=48
tidyz=00 tidx=13 from=52 to=52
tidyz=00 tidx=14 from=56 to=56
tidyz=00 tidx=15 from=60 to=60
tidyz=00 tidx=16 from=64 to=64
tidyz=00 tidx=17 from=68 to=68
tidyz=00 tidx=18 from=72 to=72
tidyz=00 tidx=19 from=76 to=76
tidyz=00 tidx=20 from=80 to=80
tidyz=00 tidx=21 from=84 to=84
tidyz=00 tidx=22 from=88 to=88
tidyz=00 tidx=23 from=92 to=92
tidyz=00 tidx=24 from=96 to=96
tidyz=00 tidx=25 from=100 to=100
tidyz=00 tidx=26 from=104 to=104
tidyz=00 tidx=27 from=108 to=108
tidyz=00 tidx=28 from=112 to=112
tidyz=00 tidx=29 from=116 to=116
tidyz=00 tidx=30 from=120 to=120
tidyz=00 tidx=31 from=124 to=124
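
Reading the two dumps above (they appear to list, per thread, the first and last output element offsets written), here is a hedged helper sketch of the coalescing check: with has_smem_epilogue = true the offsets of threads 0-31 step by exactly the 4-element vector width, so the warp covers one contiguous range, while with has_smem_epilogue = false the offsets jump by thousands of elements every four threads. The 4-wide vectorization is an assumption drawn from the vectorized stores seen in the generated kernels.

```c++
#include <cstdio>

// Returns true when each thread's starting offset is exactly vec_width elements
// past its neighbour's, i.e. the warp's vectorized stores form one contiguous
// range that the hardware can coalesce.
bool isContiguousWarpStore(const int* offsets, int n_threads, int vec_width) {
  for (int t = 1; t < n_threads; ++t) {
    if (offsets[t] != offsets[t - 1] + vec_width) {
      return false;  // gap or irregular stride: not coalesced
    }
  }
  return true;
}

int main() {
  // "to" offsets of threads 0..7 copied from the dumps above.
  const int with_smem_epilogue[8] = {0, 4, 8, 12, 16, 20, 24, 28};
  const int without_smem_epilogue[8] = {0, 2, 4, 6, 4096, 4098, 4100, 4102};

  std::printf("smem epilogue:    %s\n",
              isContiguousWarpStore(with_smem_epilogue, 8, 4) ? "coalesced" : "strided");
  std::printf("no smem epilogue: %s\n",
              isContiguousWarpStore(without_smem_epilogue, 8, 4) ? "coalesced" : "strided");
  return 0;
}
```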


A collaborator commented:
So changing has_smem_epilogue to true makes the memory access contiguous, but the perf drops? That is strange.

// a thread can use up to 255 registers, blocks per sm is limited by available
// registers
const auto threads_per_sm = getThreadsPerSMGivenRegPerThread(255);
const auto blocks_per_sm = threads_per_sm / threads_per_block;
A collaborator commented:
nit: rename to blocks_per_sm_by_register
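
For context on the register-limited bound in this snippet, a hedged standalone approximation (the real getThreadsPerSMGivenRegPerThread helper also accounts for allocation granularity; the 64K-register file size and CTA size below are assumptions for illustration):

```c++
#include <cstdio>

int main() {
  const int regs_per_sm = 64 * 1024;   // assumed SM register file size
  const int regs_per_thread = 255;     // per-thread cap used in the snippet above
  const int threads_per_block = 128;   // assumed CTA size

  // Rough bound: resident threads limited by the register file, then divided
  // by the CTA size (real hardware rounds to warp/allocation granularity).
  const int threads_per_sm = regs_per_sm / regs_per_thread;                 // ~257
  const int blocks_per_sm_by_register = threads_per_sm / threads_per_block; // 2
  std::printf("blocks_per_sm_by_register = %d\n", blocks_per_sm_by_register);
  return 0;
}
```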

@zasdfgbnm (Collaborator) left a review comment:
LGTM now. Could you evaluate the perf using the tool @drzejan2 and @mmigdal-nv were using, and update the slides? The perf eval will take a while, so let's merge first and then evaluate perf.

@mmigdal-nv (Collaborator) commented Jul 17, 2023:

> LGTM now. Could you evaluate the perf using the tool @drzejan2 and @mmigdal-nv were using, and update the slides? The perf eval will take a while, so let's merge first and then evaluate perf.

I can run this if you want; it should take a few hours. @liqiangxl, what is the extra expected smem usage related to this pull request? Is it sizeof(ComputeType) * tileM * tileN?

@liqiangxl (Collaborator, Author) commented:

> LGTM now. Could you evaluate the perf using the tool @drzejan2 and @mmigdal-nv were using, and update the slides? The perf eval will take a while, so let's merge first and then evaluate perf.
>
> I can run this if you want; it should take a few hours. @liqiangxl, what is the extra expected smem usage related to this pull request? Is it sizeof(ComputeType) * tileM * tileN?

Thanks. Yes, the extra smem usage is sizeof(ComputeType) * tileM * tileN. I think in most of the test cases, generateSharedMemoryEpilogueHeuristics will return false.
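
As a rough illustrative number: with a float compute type and a 128 x 128 CTA tile (the tile size here is only an example, not from this PR), that is 4 * 128 * 128 = 65536 bytes, i.e. 64 KiB of extra shared memory per CTA, which is consistent with the expectation above that generateSharedMemoryEpilogueHeuristics will return false for many configurations.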

@liqiangxl merged commit 3d3f405 into main on Jul 17, 2023
@liqiangxl deleted the llu/matmul_epilogue branch on July 17, 2023 17:17
jacobhinkle added a commit that referenced this pull request Mar 22, 2024
Serial grid reductions are used in split-K matmuls as of #1510. This
means we load and store elements in the reduction tensor according to
the indexing of the work buffer. This is unlike ordinary grid reductions
that use `gridReduce`, which reduces individual elements using a scheme
that ensures coalescing by indexing into the work buffer based on
`threadIdx` and `blockIdx`. Currently these split-K accesses are
inefficient due to this lack of coalescing.

We currently already ensure coalesced output stores in matmuls when
possible by using smem for the epilogue (#387). A shared memory buffer
is used to communicate elements between threads so that the resulting
tensor will have a proper global access pattern when it is written out
to global memory as a tile of the output. Before this PR, if we used
split-K with `use_smem_epilogue = true`, the store to global memory would
be coalesced but there would be uncoalesced accesses during the split-K
reduction. This PR modifies scheduling so that in those cases, the smem
epilogue tensor is placed before the split-K sum, so that unswizzling
happens before completing the reduction. The result is that the
reduction accesses are coalesced.

This is a generated kernel from
`NVFuserTest.FusionAmpereMatmulSplitKBias_CUDA`:
```c++
// ... (main loop) ...
     #pragma unroll
      for(nvfuser_index_t i59 = 0; i59 < 4LL; ++i59) {
        nvfuser_index_t i104;
        i104 = 8LL * i59;
        nvfuser_index_t i105;
        i105 = 32LL * i59;
        #pragma unroll
        for(nvfuser_index_t i61 = 0; i61 < 8LL; ++i61) {
          nvfuser_index_t i106;
          i106 = 4LL * i61;
          asm(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            :"=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
            :"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[1]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[2]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[3]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
          );
        }
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    __syncthreads();
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i107 = 0; i107 < 4LL; ++i107) {
    nvfuser_index_t i108;
    i108 = 32LL * i107;
    nvfuser_index_t i109;
    i109 = i38 + (2048LL * i107);
    #pragma unroll
    for(nvfuser_index_t i110 = 0; i110 < 8LL; ++i110) {
      nvfuser_index_t i111;
      i111 = i108 + (4LL * i110);
      nvfuser_index_t i112;
      i112 = i11 + i110;
      nvfuser_index_t i113;
      i113 = (i109 + (32LL * (i112 / 4LL))) + (8LL * (i39 ^ (i112 % 4LL)));
      #pragma unroll
      for(nvfuser_index_t i114 = 0; i114 < 2LL; ++i114) {
        loadGeneric<float, 2>( &T17[(i113 + (1024LL * i114))],  &T16[(i111 + (2LL * i114))]);
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Allocate global tensor T19
  grid_sync::blockSerializeWait<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i115 = 0; i115 < 32LL; ++i115) {
    nvfuser_index_t i116;
    i116 = i115 + nvfuser_zero;
    nvfuser_index_t i117;
    i117 = i44 + (i45 * i116);
    nvfuser_index_t i118;
    i118 = i47 + (4LL * i115);
    bool b119;
    b119 = i55 < (-(4LL * i116));
    bool b120;
    b120 = b54 && b119;
    Array<float, 4LL, 4> T6;
    T6.set(float(0.000000000e+00f));
    // Allocate global tensor T20
    reduction::serialReductionStep</*vec_size=*/4>(
      &T6[0LL],
      &T17[(i42 + (512LL * i115))],
      0.000000000e+00f,
      &T20[i117],
      [](float &a, float b) { a = a + b; },
      index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
      index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
      b120,
      b120);
    Array<float, 4LL, 4> T10;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4LL; ++i121) {
      __half T18[1LL];
      T18[0LL] = 0LL;
      if (b119) {
        T18[0LL]
           = T2[(i118 + ((i48 + (i121 + nvfuser_zero)) / 128LL))];
      }
      __half T7[1LL];
      T7[0LL]
         = T18[0LL];
      float T8[1LL];
      T8[0LL]
         = __half2float(T7[0LL]);
      T10[i121]
        = T6[i121]
        + T8[0LL];
    }
    if ((b56 && b119)) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T9[i117], &T10[0LL]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
}
```
Note that the `i135` loop will be smaller once we have #1528, at which point it will more clearly show the reduction followed by the loop for the predicated bias epilogue.

(Diff should be viewed hiding whitespace changes as many changes are to
indentation).