diff --git a/benchmarks/cpp/nvfuser/matmul.cpp b/benchmarks/cpp/nvfuser/matmul.cpp index 02ba113264740c..455722cebed568 100644 --- a/benchmarks/cpp/nvfuser/matmul.cpp +++ b/benchmarks/cpp/nvfuser/matmul.cpp @@ -215,6 +215,16 @@ size_t getSmemSize(GemmTile cta_tile, int stage_number) { dataTypeSize(DataType::Half) * stage_number; } +// This feature is for initial exploration, **not** going to be in any final PR. +int getNumOfStage(){ + auto stage = getenv("PYTORCH_NVFUSER_MATMUL_STAGE_NUMBER"); + int stage_number = 2; + if(stage){ + stage_number = atoi(stage); + } + return stage_number; +} + // TODO: this part eventually will be automated by heuristics MatmulParam getMatmulParams( GemmTile cta_tile, @@ -245,7 +255,7 @@ static void Nvfuser_Matmul_4warp( benchmark::State& benchmark_state, MatmulLayout layout) { auto cta_tile = GemmTile(128, 128, 64); - int number_of_stage = 2; + int number_of_stage = getNumOfStage(); auto params = getMatmulParams(cta_tile, number_of_stage, layout); @@ -261,7 +271,39 @@ static void Nvfuser_Matmul_8warp( benchmark::State& benchmark_state, MatmulLayout layout) { auto cta_tile = GemmTile(256, 128, 64); - int number_of_stage = 2; + int number_of_stage = getNumOfStage(); + + auto params = getMatmulParams(cta_tile, number_of_stage, layout); + + // Need a +4 on the smem size for nvfuser zero. + NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( + 8, 0, getSmemSize(cta_tile, number_of_stage) + 4, benchmark_state); + + // Run benchmark: + SingleMatmulBase(benchmark_state, layout, params); +} + +static void Nvfuser_Matmul_4warp32( + benchmark::State& benchmark_state, + MatmulLayout layout) { + auto cta_tile = GemmTile(128, 128, 32); + int number_of_stage = getNumOfStage(); + + auto params = getMatmulParams(cta_tile, number_of_stage, layout); + + // Need a +4 on the smem size for nvfuser zero. + NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( + 8, 0, getSmemSize(cta_tile, number_of_stage) + 4, benchmark_state); + + // Run benchmark: + SingleMatmulBase(benchmark_state, layout, params); +} + +static void Nvfuser_Matmul_8warp32( + benchmark::State& benchmark_state, + MatmulLayout layout) { + auto cta_tile = GemmTile(256, 128, 32); + int number_of_stage = getNumOfStage(); auto params = getMatmulParams(cta_tile, number_of_stage, layout); @@ -285,7 +327,7 @@ static void Nvfuser_Matmul_8warp( run(TN, MatmulLayout::TN); \ run(NT, MatmulLayout::NT) -// Instantiations: +// Test specifications: #define Nvfuser_4warp_test(layout_label, layout) \ BENCHMARK_CAPTURE( \ Nvfuser_Matmul_4warp, no_quant_nvfuser_4warp_##layout_label, layout) \ @@ -296,11 +338,24 @@ static void Nvfuser_Matmul_8warp( Nvfuser_Matmul_8warp, no_quant_nvfuser_8warp_##layout_label, layout) \ ->NO_TILE_QUANTIZATION_ARGS +#define Nvfuser_4warp32_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + Nvfuser_Matmul_4warp32, no_quant_nvfuser_4warp32_##layout_label, layout) \ + ->NO_TILE_QUANTIZATION_ARGS + +#define Nvfuser_8warp32_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + Nvfuser_Matmul_8warp32, no_quant_nvfuser_8warp32_##layout_label, layout) \ + ->NO_TILE_QUANTIZATION_ARGS + #define Eagermode_test(layout_label, layout) \ BENCHMARK_CAPTURE( \ EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \ ->NO_TILE_QUANTIZATION_ARGS +// Test instantiations: ForAllLayouts(Nvfuser_4warp_test); ForAllLayouts(Nvfuser_8warp_test); +ForAllLayouts(Nvfuser_4warp32_test); +ForAllLayouts(Nvfuser_8warp32_test); ForAllLayouts(Eagermode_test); \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu index 6502e7b8c4d7b7..c440157581aef7 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu @@ -1,15 +1,6 @@ -#define NVFUSER_DEFINE_MAGIC_ZERO \ - __shared__ int nvfuser_zero_s; \ - if (threadIdx.x == 0) \ - nvfuser_zero_s = 0; \ - __syncthreads(); \ - atomicMin(&nvfuser_zero_s, threadIdx.x); \ - int nvfuser_zero = nvfuser_zero_s; - -#define NVFUSER_UPDATE_MAGIC_ZERO \ - do { \ - nvfuser_zero <<= 1; \ - } while (0); +#define NVFUSER_DEFINE_MAGIC_ZERO int nvfuser_zero = 0; + +#define NVFUSER_UPDATE_MAGIC_ZERO __device__ constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b;