Unswizzle before grid reduction in split-K (#1534)
Serial grid reductions are used in split-K matmuls as of #1510. This means we load and store elements in the reduction tensor according to the indexing of the work buffer. This is unlike ordinary grid reductions that use `gridReduce`, which reduces individual elements using a scheme that ensures coalescing by indexing into the work buffer based on `threadIdx` and `blockIdx`. As a result, these split-K accesses are currently inefficient due to the lack of coalescing.

We already ensure coalesced output stores in matmuls when possible by using smem for the epilogue (#387): a shared memory buffer is used to communicate elements between threads so that the resulting tensor has a proper global access pattern when it is written out to global memory as a tile of the output.

Before this PR, if we used split-K with `use_smem_epilogue = true`, the store to global memory would be coalesced, but there would still be uncoalesced accesses during the split-K reduction. This PR modifies scheduling so that, in those cases, the smem epilogue tensor is placed before the split-K sum, meaning unswizzling happens before completing the reduction. The result is that the reduction accesses are coalesced.
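To illustrate the access-pattern difference, here is a small standalone sketch (not nvFuser-generated code; the tile shape, the `fragment_offset` layout, and the kernel names are made up for illustration). Accumulating register fragments directly into the split-K work buffer follows the scattered mma accumulator layout and is uncoalesced; staging them through shared memory first lets consecutive threads touch consecutive work-buffer addresses:

```c++
// Illustrative sketch only, not nvFuser output. Assumes a 128-thread block
// producing a 4096-element partial tile, and a work_buf slot per tile that
// all split-K blocks accumulate into serially (serialization itself omitted).
#include <cuda_runtime.h>

constexpr int kThreads = 128;
constexpr int kPerThread = 32;
constexpr int kTile = kThreads * kPerThread;  // 4096 floats per output tile

// Hypothetical register-fragment layout: element i of thread tid lives at a
// thread-major offset, so for a fixed i the 32 lanes of a warp are 32 floats
// (128 B) apart and a single warp access spans 32 separate cache lines.
__device__ int fragment_offset(int tid, int i) {
  int warp = tid / 32;
  int lane = tid % 32;
  return warp * (kTile / 4) + lane * kPerThread + i;
}

// Without unswizzling: accumulate registers straight into the work buffer
// using the fragment layout. Both the load and the store are uncoalesced.
__global__ void accumulate_swizzled(float* work_buf, bool first_block) {
  float frag[kPerThread];
  for (int i = 0; i < kPerThread; ++i) frag[i] = 1.0f;  // stand-in mma result
  for (int i = 0; i < kPerThread; ++i) {
    int idx = fragment_offset(threadIdx.x, i);
    work_buf[idx] = (first_block ? 0.0f : work_buf[idx]) + frag[i];
  }
}

// With unswizzling: stage the fragment through shared memory first, then
// accumulate linear slices so consecutive threads touch consecutive addresses.
__global__ void accumulate_unswizzled(float* work_buf, bool first_block) {
  __shared__ float smem[kTile];
  float frag[kPerThread];
  for (int i = 0; i < kPerThread; ++i) frag[i] = 1.0f;  // stand-in mma result

  // Scatter registers into shared memory using the fragment layout.
  for (int i = 0; i < kPerThread; ++i)
    smem[fragment_offset(threadIdx.x, i)] = frag[i];
  __syncthreads();

  // Read back in a linear order; the work-buffer load and store below
  // inherit the same coalesced pattern.
  for (int i = 0; i < kPerThread; ++i) {
    int idx = i * kThreads + threadIdx.x;
    work_buf[idx] = (first_block ? 0.0f : work_buf[idx]) + smem[idx];
  }
}
```

The generated kernel below follows the second pattern: the mma results in `T16` are unswizzled into the shared-memory tile `T17` and only then folded into the global work buffer `T20` by the serial reduction step.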
Here is a kernel generated from `NVFuserTest.FusionAmpereMatmulSplitKBias_CUDA`:

```c++
  // ... (main loop) ...
      #pragma unroll
      for(nvfuser_index_t i59 = 0; i59 < 4LL; ++i59) {
        nvfuser_index_t i104;
        i104 = 8LL * i59;
        nvfuser_index_t i105;
        i105 = 32LL * i59;
        #pragma unroll
        for(nvfuser_index_t i61 = 0; i61 < 8LL; ++i61) {
          nvfuser_index_t i106;
          i106 = 4LL * i61;
          asm(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            :"=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
            :"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[1]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[2]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[3]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
          );
        }
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    __syncthreads();
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i107 = 0; i107 < 4LL; ++i107) {
    nvfuser_index_t i108;
    i108 = 32LL * i107;
    nvfuser_index_t i109;
    i109 = i38 + (2048LL * i107);
    #pragma unroll
    for(nvfuser_index_t i110 = 0; i110 < 8LL; ++i110) {
      nvfuser_index_t i111;
      i111 = i108 + (4LL * i110);
      nvfuser_index_t i112;
      i112 = i11 + i110;
      nvfuser_index_t i113;
      i113 = (i109 + (32LL * (i112 / 4LL))) + (8LL * (i39 ^ (i112 % 4LL)));
      #pragma unroll
      for(nvfuser_index_t i114 = 0; i114 < 2LL; ++i114) {
        loadGeneric<float, 2>(
            &T17[(i113 + (1024LL * i114))],
            &T16[(i111 + (2LL * i114))]);
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Allocate global tensor T19
  grid_sync::blockSerializeWait<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i115 = 0; i115 < 32LL; ++i115) {
    nvfuser_index_t i116;
    i116 = i115 + nvfuser_zero;
    nvfuser_index_t i117;
    i117 = i44 + (i45 * i116);
    nvfuser_index_t i118;
    i118 = i47 + (4LL * i115);
    bool b119;
    b119 = i55 < (-(4LL * i116));
    bool b120;
    b120 = b54 && b119;
    Array<float, 4LL, 4> T6;
    T6.set(float(0.000000000e+00f));
    // Allocate global tensor T20
    reduction::serialReductionStep</*vec_size=*/4>(
        &T6[0LL],
        &T17[(i42 + (512LL * i115))],
        0.000000000e+00f,
        &T20[i117],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        b120,
        b120);
    Array<float, 4LL, 4> T10;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4LL; ++i121) {
      __half T18[1LL];
      T18[0LL] = 0LL;
      if (b119) {
        T18[0LL] = T2[(i118 + ((i48 + (i121 + nvfuser_zero)) / 128LL))];
      }
      __half T7[1LL];
      T7[0LL] = T18[0LL];
      float T8[1LL];
      T8[0LL] = __half2float(T7[0LL]);
      T10[i121] = T6[i121] + T8[0LL];
    }
    if ((b56 && b119)) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>(
          &T9[i117], &T10[0LL]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
}
```

Note that the `i135` loop will be smaller once we have #1528, at which point it will more clearly show the reduction followed by the loop for the predicated bias epilogue. (The diff should be viewed with whitespace changes hidden, as many of the changes are to indentation.)
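For context, the serial grid reduction step invoked in the kernel above behaves roughly like the following simplified sketch. This is not the nvFuser runtime implementation of `reduction::serialReductionStep`; vectorized access, predication, and memory fences are omitted, and the parameter names are invented. Blocks along the split-K dimension run their step one at a time, ordered by `blockSerializeWait`/`blockSerializeRelease`, so the work buffer carries a running partial sum and the last block ends up with the complete reduction:

```c++
// Simplified sketch of a serial grid-reduction step; not nvFuser's actual
// reduction::serialReductionStep. Assumes the caller already serializes
// blocks via grid_sync::blockSerializeWait/blockSerializeRelease.
template <int vec_size, typename T, typename Func>
__device__ void serial_reduction_step_sketch(
    T* dst,          // per-thread running result, consumed by the epilogue
    const T* in,     // this block's contribution (here: the unswizzled smem tile)
    T init,          // reduction identity
    T* work,         // global work buffer shared by all split-K blocks
    Func op,         // e.g. [](T& a, T b) { a = a + b; }
    bool is_first,   // first block in the serial order
    bool is_last) {  // last block in the serial order
  for (int i = 0; i < vec_size; ++i) {
    T acc = is_first ? init : work[i];  // pick up the previous partial sum
    op(acc, in[i]);                     // fold in this block's elements
    dst[i] = acc;
    if (!is_last) {
      work[i] = acc;  // hand the partial sum to the next block in line
    }
  }
}
```

In the generated kernel, `T6` plays the role of the per-thread result, the unswizzled shared-memory tile `T17` is the input, and `T20` is the global work buffer; because `T17` is read in the unswizzled order, the `T20` accesses are coalesced as well.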