Skip to content

Commit

Permalink
[Hack] Misc plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Sep 8, 2022
1 parent f12b3cc commit 47362b6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
61 changes: 58 additions & 3 deletions benchmarks/cpp/nvfuser/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ size_t getSmemSize(GemmTile cta_tile, int stage_number) {
dataTypeSize(DataType::Half) * stage_number;
}

// Exploration-only knob: this is **not** going to be in any final PR.
//
// Returns the stage number used by the matmul benchmarks in this file.
// The value is read from the PYTORCH_NVFUSER_MATMUL_STAGE_NUMBER environment
// variable. The default of 2 is returned when the variable is unset, when it
// does not start with a decimal integer, or when that integer is not a
// positive value representable as an int.
int getNumOfStage() {
  constexpr int kDefaultStageNumber = 2;

  const char* stage = getenv("PYTORCH_NVFUSER_MATMUL_STAGE_NUMBER");
  if (stage == nullptr) {
    return kDefaultStageNumber;
  }

  // strtol (unlike atoi) reports whether any digits were consumed, so an
  // empty or non-numeric value no longer silently turns into 0 stages.
  char* parse_end = nullptr;
  const long parsed = strtol(stage, &parse_end, 10);
  if (parse_end == stage) {
    return kDefaultStageNumber;
  }

  // A stage number below 1 is meaningless, and a value that does not
  // round-trip through int would be truncated on return; fall back instead.
  const int stage_number = static_cast<int>(parsed);
  if (parsed < 1 || static_cast<long>(stage_number) != parsed) {
    return kDefaultStageNumber;
  }
  return stage_number;
}

// TODO: this part eventually will be automated by heuristics
MatmulParam getMatmulParams(
GemmTile cta_tile,
Expand Down Expand Up @@ -245,7 +255,7 @@ static void Nvfuser_Matmul_4warp(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(128, 128, 64);
int number_of_stage = 2;
int number_of_stage = getNumOfStage();

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

Expand All @@ -261,7 +271,39 @@ static void Nvfuser_Matmul_8warp(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(256, 128, 64);
int number_of_stage = 2;
int number_of_stage = getNumOfStage();

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

// Need a +4 on the smem size for nvfuser zero.
NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage) + 4, benchmark_state);

// Run benchmark:
SingleMatmulBase(benchmark_state, layout, params);
}

// Benchmark body for a GemmTile(128, 128, 32) CTA tile.
//
// The stage number comes from getNumOfStage(); the matmul parameters are
// built with getMatmulParams() and the run itself is delegated to
// SingleMatmulBase().
static void Nvfuser_Matmul_4warp32(
    benchmark::State& benchmark_state,
    MatmulLayout layout) {
  const int stages = getNumOfStage();
  const GemmTile tile(128, 128, 32);

  MatmulParam params = getMatmulParams(tile, stages, layout);

  // nvfuser zero needs 4 extra units on top of the tile buffers.
  const size_t required_smem = getSmemSize(tile, stages) + 4;

  // NOTE(review): presumably bails out of the benchmark when the device is
  // below arch 8.0 or lacks the requested smem -- confirm against the macro.
  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(8, 0, required_smem, benchmark_state);

  // Hand over to the shared matmul benchmark driver.
  SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_8warp32(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(256, 128, 32);
int number_of_stage = getNumOfStage();

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

Expand All @@ -285,7 +327,7 @@ static void Nvfuser_Matmul_8warp(
run(TN, MatmulLayout::TN); \
run(NT, MatmulLayout::NT)

// Instantiations:
// Test specifications:
#define Nvfuser_4warp_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_4warp, no_quant_nvfuser_4warp_##layout_label, layout) \
Expand All @@ -296,11 +338,24 @@ static void Nvfuser_Matmul_8warp(
Nvfuser_Matmul_8warp, no_quant_nvfuser_8warp_##layout_label, layout) \
->NO_TILE_QUANTIZATION_ARGS

// Expands to a BENCHMARK_CAPTURE of Nvfuser_Matmul_4warp32 under the name
// no_quant_nvfuser_4warp32_<layout_label>, with `layout` as the captured
// argument, followed by ->NO_TILE_QUANTIZATION_ARGS.
#define Nvfuser_4warp32_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_4warp32, no_quant_nvfuser_4warp32_##layout_label, layout) \
->NO_TILE_QUANTIZATION_ARGS

// Expands to a BENCHMARK_CAPTURE of Nvfuser_Matmul_8warp32 under the name
// no_quant_nvfuser_8warp32_<layout_label>, with `layout` as the captured
// argument, followed by ->NO_TILE_QUANTIZATION_ARGS.
#define Nvfuser_8warp32_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_8warp32, no_quant_nvfuser_8warp32_##layout_label, layout) \
->NO_TILE_QUANTIZATION_ARGS

// Expands to a BENCHMARK_CAPTURE of EagerModeMatmul under the name
// no_quant_eagermode_<layout_label>, with `layout` as the captured argument,
// followed by ->NO_TILE_QUANTIZATION_ARGS.
#define Eagermode_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \
->NO_TILE_QUANTIZATION_ARGS

// Test instantiations:
// Each line hands one of the *_test registration macros to ForAllLayouts.
// NOTE(review): only the tail of ForAllLayouts is visible here (it ends with
// run(TN, ...) and run(NT, ...)); confirm the full layout list at its
// definition.
ForAllLayouts(Nvfuser_4warp_test);
ForAllLayouts(Nvfuser_8warp_test);
ForAllLayouts(Nvfuser_4warp32_test);
ForAllLayouts(Nvfuser_8warp32_test);
ForAllLayouts(Eagermode_test);
15 changes: 3 additions & 12 deletions torch/csrc/jit/codegen/cuda/runtime/helpers.cu
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
// Declares a __shared__ int `nvfuser_zero_s`, has the thread with
// threadIdx.x == 0 store 0 into it, calls __syncthreads(), applies
// atomicMin(&nvfuser_zero_s, threadIdx.x), and finally copies the shared value
// into a local int named `nvfuser_zero`.
#define NVFUSER_DEFINE_MAGIC_ZERO \
__shared__ int nvfuser_zero_s; \
if (threadIdx.x == 0) \
nvfuser_zero_s = 0; \
__syncthreads(); \
atomicMin(&nvfuser_zero_s, threadIdx.x); \
int nvfuser_zero = nvfuser_zero_s;

// Shifts `nvfuser_zero` left by one bit inside a do/while(0) wrapper.
// Note that the trailing `;` after `while (0)` is part of the macro text.
#define NVFUSER_UPDATE_MAGIC_ZERO \
do { \
nvfuser_zero <<= 1; \
} while (0);
// Defines `nvfuser_zero` as a plain local int initialized to 0.
#define NVFUSER_DEFINE_MAGIC_ZERO int nvfuser_zero = 0;

// Expands to nothing, so every use of this macro is a no-op.
#define NVFUSER_UPDATE_MAGIC_ZERO

__device__ constexpr int ceilDiv(int a, int b) {
return (a + b - 1) / b;
Expand Down

0 comments on commit 47362b6

Please sign in to comment.