Fix smem swizzle for matmul (#588)
In #387, we have to do a
```C++
    int64_t swizzle_period =
        std::gcd(n_rows / repeated_pattern_size, tile_size_y / n_cols);
```
in order to make our swizzling algorithm work for the epilogue. This looks
more like an empirical hack whose only goal is to create a square block.
Although it worked empirically, I struggled to find a first-principles
explanation for it. So I read through my original PR #155 multiple times
and thought things through carefully. But the more I read and thought, the
more I felt that the original implementation in #155 does not make sense.
The problem is that #155 tries to interleave all `ldmatrix_rows /
repeated_pattern_size` rows with an equal-size split on the tile's y
dimension. This is overkill: we only need to distribute rows evenly across
the different megabanks, and as long as we do, the number of rows can be
arbitrarily large and we are still bank-conflict free. So we should be
swizzling on a `(g, g)` block instead of a (potentially much larger)
`(ldmatrix_rows / repeated_pattern_size, ldmatrix_rows /
repeated_pattern_size)` block.
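As a rough illustration of the sizes involved (the numbers below are made up for the example and are not taken from the scheduler), the #387 workaround ties the swizzle period to the tile shape via `std::gcd`, while the fix here only ever needs a period of `g`, the number of gigabanks:

```C++
#include <cstdint>
#include <cstdio>
#include <numeric>

int main() {
  // Hypothetical configuration, chosen only to illustrate the block sizes:
  int64_t n_rows = 16;               // rows in one ldmatrix unit
  int64_t repeated_pattern_size = 2; // rows before the bank pattern repeats
  int64_t tile_size_y = 64;          // tile extent in the y dimension
  int64_t n_cols = 8;                // columns in one ldmatrix unit
  int64_t g = 4;                     // number of gigabanks (made up here; the
                                     // real scheduler derives it from the
                                     // shared-memory bank layout)

  // The #387 workaround: shrink the period until the swizzled block is square.
  int64_t swizzle_period =
      std::gcd(n_rows / repeated_pattern_size, tile_size_y / n_cols);

  std::printf("swizzle period tied to the tile (#387): %ld\n",
              static_cast<long>(swizzle_period));
  std::printf("swizzle period needed for the gigabanks (#588): %ld\n",
              static_cast<long>(g));
  return 0;
}
```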
zasdfgbnm authored Jul 13, 2023
1 parent 9713d78 commit 2e4c1e5
Showing 1 changed file with 68 additions and 28 deletions.
csrc/scheduler/matmul.cpp
```diff
@@ -349,40 +349,80 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) {
  * 6|          |          |
  * 7|          |          |
  * +----------+----------+
+ *
+ * We can consider each repeated_pattern_size rows as a gigarow, and each
+ * repeated_pattern_size megabanks as a gigabank. Note that megabank is a
+ * contiguous chunk of banks, but gigabank is not contiguous. Indeed,
+ * nearby megabanks in a gigabank has a distance of `g` megabanks
  */
 
-  TORCH_INTERNAL_ASSERT(
-      tile_size_x % ldmatrix_rows == 0, "Partial matrices not supported");
-  shared_mem_tv->split(-2, ldmatrix_rows);
-  TORCH_INTERNAL_ASSERT(
-      tile_size_y % ldmatrix_cols == 0, "Partial matrices not supported");
-  shared_mem_tv->split(-1, ldmatrix_cols);
-  //    -4       -3       -2       -1
-  // [matrix id, matrix, matrix id, matrix]
-  TORCH_INTERNAL_ASSERT(
-      ldmatrix_rows % repeated_pattern_size == 0,
-      "ldmatrix_rows is assumed to be a multiple of repeated_pattern_size");
-  shared_mem_tv->split(-3, repeated_pattern_size);
-  //    -5       -4      -3       -2       -1
-  // [matrix id, repeat, pattern, matrix id, matrix]
-  int64_t swizzle_period = ldmatrix_rows / repeated_pattern_size;
-  TORCH_INTERNAL_ASSERT(
-      tile_size_y % (swizzle_period * ldmatrix_cols) == 0,
-      "need aperiodic swizzle config for tile size ",
-      tile_size_x,
-      "x",
-      tile_size_y,
-      "with units ",
-      ldmatrix_rows,
-      "x",
-      ldmatrix_cols);
-  shared_mem_tv->split(-2, swizzle_period);
-  //    -6       -5      -4         -3           -2        -1
-  // [matrix id, repeat, pattern, matrix id outer, pattern id, matrix]
-  // swizzle repeat with pattern id to make repeat no longer repeat
-  if (isPowOf2(swizzle_period)) {
+  TORCH_INTERNAL_ASSERT(
+      ldmatrix_rows % repeated_pattern_size == 0,
+      "Can not partition matrix into megarows");
+  int64_t num_gigarows = ldmatrix_rows / repeated_pattern_size;
+  int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size
+
+  //   -2    -1
+  // [row, col]
+  shared_mem_tv->split(-2, repeated_pattern_size);
+  shared_mem_tv->split(-1, ldmatrix_cols);
+  //     -4       -3        -2       -1
+  // [gigarow id, gigarow, matrix id, matrix]
+  shared_mem_tv->split(-2, num_gigabanks);
+  //     -5       -4       -3        -2        -1
+  // [gigarow id, gigarow, y outer, gigabank id, matrix]
+  // Note that megabanks inside a gigabank are not contiguous, so the gigabank
+  // id is -2 instead of -3
+
+  /* We want to evenly distribute gigarows across gigabanks, for example, if
+   * we have 7 gigarows and 3 gigabanks, then we might distribute them as:
+   *  +---+
+   *  |x  |
+   *  | x |
+   *  |  x|
+   *  |x  |
+   *  | x |
+   *  |  x|
+   *  |x  |
+   *  +---+
+   * considering all matrices, this is a swizzle function like:
+   *  +---+
+   *  |012|
+   *  |201|
+   *  |120|
+   *  |012|
+   *  |201|
+   *  |120|
+   *  |012|
+   *  +---+
+   * which is a cyclic shift.
+   *
+   * Note that because num_gigabanks (a.k.a. g) divide num_megabanks and
+   * row_stride_znz (which is row_stride % num_megabanks), g should also
+   * divide row_stride, because according to the fundamental
+   * division-with-remainder property (see comment in expr_simplifier.h):
+   *   row_stride = q * num_megabanks + row_stride_znz
+   * which means, we can just consider each num_gigabanks matrices as a group,
+   * and we always have complete groups (i.e. no group has less than
+   * num_gigabanks matrices). Interleaving the memory of matrices within each
+   * group should be enough to fully remove bank conflict.
+   */
+
+  /* To further simplify the problem, if we assume: */
+  TORCH_INTERNAL_ASSERT(
+      num_gigarows % num_gigabanks == 0,
+      "Requires non-square swizzle, which is not supported yet");
+  /* Then we can partition gigarows into full waves, each wave has
+   * num_gigabanks gigarows. This partition creates square dimensions, making
+   * the swizzle implementation easier */
+
+  //     -5       -4       -3        -2        -1
+  // [gigarow id, gigarow, y outer, gigabank id, matrix]
+  shared_mem_tv->split(-5, num_gigabanks);
+  //    -6     -5    -4        -3        -2        -1
+  // [wave id, wave, gigarow, y outer, gigabank id, matrix]
+
+  if (isPowOf2(num_gigabanks)) {
     shared_mem_tv->swizzle(Swizzle2DType::XOR, -5, -2);
   } else {
     shared_mem_tv->swizzle(Swizzle2DType::CyclicShift, -5, -2);
```
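The even-distribution argument in the new comment block can be checked on its own. The sketch below is a toy model of the picture in that comment, with made-up sizes (it is not the scheduler's indexing code): a cyclic shift over a `(g, g)` block sends the column that used to hit gigabank `b` in wave-local gigarow `w` to gigabank `(b + w) % g`, so within every wave the `g` gigarows land on `g` distinct gigabanks, however many waves (and therefore gigarows) the tile has. When `g` is a power of two, the XOR mapping `b ^ w` has the same distinct-bank property, which matches the `isPowOf2(num_gigabanks)` branch above.

```C++
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <set>

int main() {
  // Hypothetical sizes, for illustration only:
  const int64_t g = 3;             // num_gigabanks
  const int64_t num_gigarows = 12; // any multiple of g works; it need not be g

  // Cyclic-shift swizzle of the (g, g) block: within its wave, gigarow w
  // remaps the column that previously hit gigabank b to gigabank (b + w) % g.
  auto swizzled_gigabank = [g](int64_t gigarow, int64_t b) {
    int64_t w = gigarow % g; // position of this gigarow within its wave
    return (b + w) % g;
  };

  // Without the swizzle, every gigarow of a fixed column hits the same
  // gigabank. With it, the g gigarows of each wave spread over g distinct
  // gigabanks for every column -- independent of how many waves there are.
  for (int64_t wave = 0; wave < num_gigarows / g; ++wave) {
    for (int64_t b = 0; b < g; ++b) {
      std::set<int64_t> banks;
      for (int64_t w = 0; w < g; ++w) {
        banks.insert(swizzled_gigabank(wave * g + w, b));
      }
      assert(static_cast<int64_t>(banks.size()) == g); // no two rows collide
    }
  }
  std::printf("%ld gigarows spread over %ld gigabanks without conflicts\n",
              static_cast<long>(num_gigarows), static_cast<long>(g));
  return 0;
}
```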
