From 7b37a836bc0e55ae17f38d7b92974dd6ef380718 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 15 Feb 2023 12:00:27 -0800 Subject: [PATCH] Enable grid outer persistent scheduling (#2435) * Enable grid outer persistent scheduling --- third_party/nvfuser/CMakeLists.txt | 1 + third_party/nvfuser/benchmark/utils.cpp | 5 +- .../nvfuser/csrc/scheduler/debug_utils.h | 27 + .../nvfuser/csrc/scheduler/normalization.cpp | 176 +++- .../csrc/scheduler/normalization_utils.cpp | 504 ++++++++++ .../csrc/scheduler/normalization_utils.h | 155 ++++ .../nvfuser/csrc/scheduler/reduction.cpp | 6 +- .../csrc/scheduler/reduction_heuristic.h | 53 +- .../csrc/scheduler/reduction_utils.cpp | 151 ++- .../nvfuser/csrc/scheduler/registry.cpp | 212 ++++- third_party/nvfuser/csrc/scheduler/utils.cpp | 4 +- third_party/nvfuser/csrc/scheduler/utils.h | 2 +- third_party/nvfuser/csrc/utils.cpp | 1 + third_party/nvfuser/csrc/utils.h | 1 + .../nvfuser/test/test_gpu_outer_reduction.cpp | 861 ++++++++++++++++++ 15 files changed, 2093 insertions(+), 66 deletions(-) create mode 100644 third_party/nvfuser/csrc/scheduler/normalization_utils.cpp create mode 100644 third_party/nvfuser/csrc/scheduler/normalization_utils.h diff --git a/third_party/nvfuser/CMakeLists.txt b/third_party/nvfuser/CMakeLists.txt index 88839f1e90ea6..3bb29a34b3b1e 100644 --- a/third_party/nvfuser/CMakeLists.txt +++ b/third_party/nvfuser/CMakeLists.txt @@ -111,6 +111,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp diff --git a/third_party/nvfuser/benchmark/utils.cpp b/third_party/nvfuser/benchmark/utils.cpp index 51aa1d67cf533..2d266e121d6e2 100644 --- a/third_party/nvfuser/benchmark/utils.cpp +++ b/third_party/nvfuser/benchmark/utils.cpp @@ -29,7 +29,10 @@ std::string toString(const ReductionParams& rparams) { ss << " // Iteration Domain: " << (rparams.multiple_reds_per_blk ? "multiple reductions per block / " : "") - << (rparams.split_grid_dim_iter_dom ? "split grid dimension / " : "") + << ((rparams.split_grid_dim_iter_dom_inner || + rparams.split_grid_dim_iter_dom_outer) + ? "split grid dimension / " + : "") << (rparams.vectorize_iter_dom ? "vectorize / " : "") << (rparams.unroll_factor_iter_dom > 1 && !rparams.vectorize_iter_dom ? "unroll / " diff --git a/third_party/nvfuser/csrc/scheduler/debug_utils.h b/third_party/nvfuser/csrc/scheduler/debug_utils.h index c059eea19cd64..460313efb391c 100644 --- a/third_party/nvfuser/csrc/scheduler/debug_utils.h +++ b/third_party/nvfuser/csrc/scheduler/debug_utils.h @@ -1,5 +1,9 @@ #pragma once +#include + +#include + namespace torch { namespace jit { namespace fuser { @@ -26,6 +30,29 @@ void canScheduleRejectReason(HeuristicType heuristic, const Args&... args) { "Scheduler _", heuristic, "_ ***rejected*** because : ", args...); } +// Based on +// https://learn.microsoft.com/en-us/cpp/cpp/ellipses-and-variadic-templates?view=msvc-170#example +inline void log() { + if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) { + std::cerr << std::endl; + } +} + +template +void log(const T& t) { + if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) { + std::cerr << t << std::endl; + } +} + +template +void log(const First& first, const Rest&... 
rest) { + if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) { + std::cerr << first; + log(rest...); + } +} + } // namespace scheduler_debug_utils } // namespace cuda diff --git a/third_party/nvfuser/csrc/scheduler/normalization.cpp b/third_party/nvfuser/csrc/scheduler/normalization.cpp index cf48d2eb92851..7da19c0c43764 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization.cpp @@ -1,10 +1,12 @@ #include #include +#include #include #include #include #include +#include #include #include #include @@ -13,6 +15,8 @@ #include +#include + namespace torch { namespace jit { namespace fuser { @@ -516,7 +520,7 @@ std::shared_ptr innerPersistentHeuristic( if (godim > 1) { rparams->grid_dim_iter_dom = ParallelType::BIDx; if (godim > scheduler_utils::x_grid_limit) { - rparams->split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom_outer = true; gdimx = scheduler_utils::x_grid_limit; } } @@ -566,6 +570,77 @@ std::shared_ptr innerPersistentHeuristic( return rparams; } +// Heuristics for grid outer normalizations +std::shared_ptr gridOuterPersistentHeuristic( + const int64_t total_reduction_numel, + const int64_t total_iteration_numel, + const int64_t n_tensor_inputs, + const int64_t max_input_dtype_size, + const int64_t max_persistent_buffer_size, + const size_t vectorize_factor) { + auto outer_params = + normalization_scheduler_utils::getGridOuterNormalizationParams( + total_reduction_numel, + total_iteration_numel, + vectorize_factor, + max_persistent_buffer_size); + + TORCH_INTERNAL_ASSERT(outer_params.has_value(), "No valid config found"); + + const auto pb_size = outer_params->persistent_buffer_factor; + const auto unswitch_factor = outer_params->unswitch_factor; + + auto rparams = std::make_shared(); + + rparams->persistent_kernel = true; + rparams->cross_block_inner_reduction = true; + rparams->cross_grid_inner_reduction = true; + rparams->grid_dim_iter_dom = ParallelType::BIDx; + rparams->grid_dim_inner_reduction = ParallelType::BIDy; + rparams->block_dim_inner_reduction = ParallelType::TIDy; + rparams->batches_per_block_inner_reduction = pb_size; + rparams->multiple_reds_per_blk = true; + rparams->vectorize_iter_dom = true; + rparams->unroll_factor_iter_dom = vectorize_factor; + rparams->block_dim_iter_dom = ParallelType::TIDx; + rparams->unroll_factor_inner_reduction = unswitch_factor; + rparams->split_grid_dim_iter_dom_inner = + ceilDiv( + total_iteration_numel / vectorize_factor, + outer_params->launch_params.bdimx()) > + outer_params->launch_params.gdimx(); + rparams->compute_persistent_buffer_with_first_consumer = true; + rparams->static_bdimx = true; + rparams->static_bdimy = true; + + rparams->lparams = LaunchParams( + rparams->split_grid_dim_iter_dom_inner + ? 
outer_params->launch_params.gdimx() + : LaunchParams::UNINITIALIZED_VAL, + LaunchParams::UNINITIALIZED_VAL, + LaunchParams::UNINITIALIZED_VAL, + outer_params->launch_params.bdimx(), + outer_params->launch_params.bdimy(), + LaunchParams::UNINITIALIZED_VAL); + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + std::cerr << "\n===== Reduction Stats ========\n" + << "total_reduction_numel: " << total_reduction_numel << "\n" + << "total_iteration_numel: " << total_iteration_numel << "\n" + << "vectorize_factor: " << vectorize_factor << "\n" + << "n_tensor_inputs: " << n_tensor_inputs << "\n" + << "max_input_dtype_size: " << max_input_dtype_size << "\n" + << "max_persistent_buffer_size: " << max_persistent_buffer_size + << "\n" + << "persistent_buffer_factor: " << pb_size << "\n" + << "block(" << outer_params->launch_params.bdimx() << ", " + << outer_params->launch_params.bdimy() << ", 1)" << std::endl; + std::cerr << rparams->toString() << std::endl; + } + + return rparams; +} + // Copied from reduction scheduler, should generalize. Simply needed to take out // grid reductions. // TODO: Check adding iteration domain unrolling @@ -587,13 +662,6 @@ std::shared_ptr outerPersistentHeuristic( const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - auto const max_unroll = ceilDiv( - // Available unrolling based on size of data type - (int64_t)16 / (int64_t)max_input_dtype_size, - // Reduce unrolling if we have many inputs, start reduction at 4 inputs - scheduler_utils::lastPow2( - std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1))); - // If it fits in l2, we just want to make sure each warp uses 32Bytes. Set // minimum warp as 16 threads instead of 32 as if we have a small reduction // dim going a bit smaller than 32 usually helps. @@ -602,16 +670,44 @@ std::shared_ptr outerPersistentHeuristic( ? (int64_t)32 / max_input_dtype_size : 16; - // Initialization + const auto register_file_size = + at::cuda::getCurrentDeviceProperties()->regsPerBlock * sizeof(int); + + // Each block runs N reductions, where N is defined as: + // vectorize_factor * blockDim.x. The minimum number of SMs to run + // this as a persistent kernel is thus defined as: + const int64_t min_required_sm_per_norm = ceilDiv( + max_persistent_buffer_size * vectorize_factor * + normalization_scheduler_utils::PreferredLaunchConfig::kMinBdimx, + register_file_size); + + if (min_required_sm_per_norm > 1) { + return gridOuterPersistentHeuristic( + total_reduction_numel, + total_iteration_numel, + n_tensor_inputs, + max_input_dtype_size, + max_persistent_buffer_size, + vectorize_factor); + } + int64_t target_blocks = 1; int64_t target_unroll = 1; int64_t max_threads_in_block = warp_size; - // If we have one warp per block, check if that's enough to saturate the SMs. - // Blocks can't come out of reduction dimension, so only use iteration - // dimension here. + // If we have one warp per block, check if that's enough to saturate the + // SMs. Blocks can't come out of reduction dimension, so only use + // iteration dimension here. 
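+  // For instance (hypothetical sizes, not taken from the heuristic
+  // itself): with total_iteration_numel = 4096 and warp_size = 16,
+  // target_blocks starts at ceilDiv(4096, 16) = 256, which is more
+  // than a wave on current devices, so the surplus is folded into
+  // unrolling below.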
target_blocks = ceilDiv(total_iteration_numel, (int64_t)warp_size); + const auto max_unroll = ceilDiv( + // Available unrolling based on size of data type + (int64_t)16 / (int64_t)max_input_dtype_size, + // Reduce unrolling if we have many inputs, start reduction at 4 + // inputs + scheduler_utils::lastPow2( + scheduler_utils::safeDiv((int64_t)n_tensor_inputs, 4))); + // If we have more than a wave of blocks, put parallelism into unrolling if (target_blocks > device_multiprocessor_count) { target_unroll = std::min( @@ -636,10 +732,8 @@ std::shared_ptr outerPersistentHeuristic( // Compute maximum number of reductions we could do in the same kernel based // on persistent buffer size - - const int64_t max_multi_reduction_factor = std::max( - scheduler_utils::register_file_size / max_persistent_buffer_size, - (int64_t)1); + const int64_t max_multi_reduction_factor = scheduler_utils::safeDiv( + scheduler_utils::register_file_size, max_persistent_buffer_size); // To get to target threads: // Prioritize @@ -652,16 +746,11 @@ std::shared_ptr outerPersistentHeuristic( // (2) y dim in multiple reductions - need to flip unrolling to reduction // domain for this - // Blocks for outputs - // int64_t gdimx = 1; // unused at this time, comment for clang tidy - // Threads for reduction int64_t bdimy = 1; // Threads for output int64_t bdimx = 1; - int64_t gdimx = 1; - // Unroll amount int64_t inner_reduction_unroll_factor = 1; int64_t iter_unroll_factor = 1; @@ -673,12 +762,12 @@ std::shared_ptr outerPersistentHeuristic( bdimx = scheduler_utils::lastPow2(bdimx); } - // Prioritie unrolling on iteration domain, but don't sacrifice occupancy, + // Prioritize unrolling on iteration domain, but don't sacrifice occupancy, // make sure there is at least one wave. if (ceilDiv(total_iteration_numel, bdimx) > 2 * device_multiprocessor_count) { iter_unroll_factor = std::min( std::min( - std::max(max_multi_reduction_factor / bdimx, (int64_t)1), + scheduler_utils::safeDiv(max_multi_reduction_factor, bdimx), max_unroll), ceilDiv(device_multiprocessor_count, bdimx)); } @@ -690,10 +779,10 @@ std::shared_ptr outerPersistentHeuristic( // Put more into bdimx bdimx = std::min( std::min( - std::max( + scheduler_utils::safeDiv( // Don't exceed multi reduction factor - max_multi_reduction_factor / iter_unroll_factor, - (int64_t)1), + max_multi_reduction_factor, + iter_unroll_factor), // Leave a full wave of blocks ceilDiv( total_iteration_numel, @@ -711,7 +800,7 @@ std::shared_ptr outerPersistentHeuristic( // Fill bdimy with left over threads bdimy = std::min( - std::max(max_threads_in_block / bdimx, (int64_t)1), + scheduler_utils::safeDiv(max_threads_in_block, bdimx), total_reduction_numel); bool vectorize = false; @@ -724,6 +813,15 @@ std::shared_ptr outerPersistentHeuristic( (int64_t)vectorize_factor); } + int64_t sm_required_per_norm_set = ceilDiv( + max_persistent_buffer_size * bdimx * iter_unroll_factor, + scheduler_utils::register_file_size); + + TORCH_INTERNAL_ASSERT( + sm_required_per_norm_set == 1, + "Tried to use multiple SMs on an outer persistent kernel ", + "yet this kernel should have been within block persistent."); + // Since this is persistent and registers will have to be used anyways unroll // the reduction dim if it's available inner_reduction_unroll_factor = @@ -751,7 +849,7 @@ std::shared_ptr outerPersistentHeuristic( batches_per_block != roundUpPow2Or8(batches_per_block / 2)) { batches_per_block = roundUpPow2Or8(batches_per_block / 2); - // Adjust bdimx based on batches_per_block and unroll factor set + 
// Adjust bdimy based on batches_per_block and unroll factor set bdimy = ceilDiv( total_reduction_numel, inner_reduction_unroll_factor * batches_per_block); @@ -775,7 +873,7 @@ std::shared_ptr outerPersistentHeuristic( bdimx = ceilDiv(bdimx, 2); } - gdimx = ceilDiv(total_iteration_numel, bdimx); + int gdimx = ceilDiv(total_iteration_numel, bdimx); auto rparams = std::make_shared(); rparams->batches_per_block_inner_reduction = batches_per_block; @@ -791,7 +889,8 @@ std::shared_ptr outerPersistentHeuristic( } rparams->grid_dim_iter_dom = ParallelType::BIDx; - rparams->split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; + rparams->split_grid_dim_iter_dom_outer = + gdimx > scheduler_utils::x_grid_limit; if (rparams->block_dim_iter_dom == ParallelType::TIDx) { rparams->block_dim_inner_reduction = ParallelType::TIDy; @@ -1052,8 +1151,6 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { // fusion segmentation scheduler_utils::clearMemorySpace(fusion); - auto persistent_info = scheduler_utils::persistentBuffers(fusion); - auto reduction_tvs = scheduler_utils::getReductionTvs(fusion); TORCH_INTERNAL_ASSERT(reduction_tvs.size()); @@ -1072,6 +1169,11 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { scheduler_utils::domainReorderAsRfactorMap(reduction_tv)); } + if (rparams.persistent_kernel && rparams.cross_grid_inner_reduction && + !rparams.fastest_dim && reduction_tvs.size() > 1) { + groupReductions(reduction_tvs, false); + } + auto dim_analysis = scheduler_utils::canonicalDimReduction( fusion, reduction_tv, rparams.fastest_dim && rparams.schedule_3D); bool has_iter_axis = dim_analysis.first; @@ -1099,6 +1201,7 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { for (auto output : dummy_outputs) { fusion->addOutput(output); } + reduction_scheduler_utils::multiReductionInliner( fusion, rparams, @@ -1108,6 +1211,15 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { cached_inputs, cached_outputs, dummy_outputs); + + if (rparams.compute_persistent_buffer_with_first_consumer) { + TORCH_INTERNAL_ASSERT( + rparams.persistent_kernel, + "computeWith should be only used with persistent kernels"); + for (const auto persistent_buffer : cached_inputs) { + persistent_buffer->computeWith(-1, true); + } + } } } // namespace cuda diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp new file mode 100644 index 0000000000000..b7ad9905594b8 --- /dev/null +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -0,0 +1,504 @@ +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace normalization_scheduler_utils { + +using scheduler_debug_utils::log; + +PreferredLaunchConfig::PreferredLaunchConfig() : valid_(true) { + initValidGdims(); + resetBdim(); + resetGdim(); +} + +bool PreferredLaunchConfig::isNextSmallerBdimx() const { + return grid_dims_pos_ + 1 == (int)valid_grid_dims_.size(); +} + +bool PreferredLaunchConfig::canLowerBdimx() const { + return bdimx() > kMinBdimx; +} + +bool PreferredLaunchConfig::setBdimx(int bdimx, bool dry_run) { + constexpr int block_size = 256; + + if (bdimx < kMinBdimx || bdimx > kMaxBdimx) { + return false; + } + + TORCH_INTERNAL_ASSERT(block_size % bdimx == 0, "Invalid bdimx: ", bdimx); + int bdimy = block_size / bdimx; + + if (!dry_run) { + bdimy_ = bdimy; + bdimx_ = bdimx; + } + + return 
true; +} + +// Populate the list of valid gridDim configs for persistent grid +// normalization kernels in the order of increasing gridDim.y. +// Start +// with gridDim.y == 2. For example, on A100, the list would be: [(54, +// 2), (36, 3), (27, 4), (21, 5), (18, 6), (15, 7), (13, 8), (12, 9), +// (10, 10), (9, 12), (8, 13), (7, 15), (6, 18), (5, 21), (4, 27), (3, +// 36), (2, 54)]. +void PreferredLaunchConfig::initValidGdims() { + std::vector> grid_dims; + const int num_sms = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + const int max_first_half = + static_cast(std::sqrt(static_cast(num_sms))); + for (int gdimy = 2; gdimy <= max_first_half; ++gdimy) { + int gdimx = num_sms / gdimy; + grid_dims.push_back(std::make_pair(gdimx, gdimy)); + } + // Reverse the first half and swap gridDim.x and gridDim.y. That + // list becomes the latter half + auto latter_half = grid_dims; + std::reverse(latter_half.begin(), latter_half.end()); + for (const auto& gdimx_gdimy : latter_half) { + if (gdimx_gdimy.second == gdimx_gdimy.first) { + // This is already in the first half + continue; + } + grid_dims.push_back(std::make_pair(gdimx_gdimy.second, gdimx_gdimy.first)); + } + valid_grid_dims_ = grid_dims; +} + +bool PreferredLaunchConfig::moveToNextConfig() { + if (moveToNextGdim()) { + return true; + } + + // Can't increase gdimy. Try bdimy next. + if (moveToNextBdim()) { + return true; + } + + // No more valid config + invalidate(); + return false; +} + +bool PreferredLaunchConfig::moveToNextBdim() { + const int new_bdimx = bdimx() / 2; + if (setBdimx(new_bdimx)) { + resetGdim(); + return true; + } else { + invalidate(); + return false; + } +} + +bool PreferredLaunchConfig::moveToNextGdim() { + auto grid_dims_next_pos = getNextGdimsPos(); + if (grid_dims_next_pos >= 0) { + grid_dims_pos_ = grid_dims_next_pos; + return true; + } else { + return false; + } +} + +int PreferredLaunchConfig::peekNextGdimx() const { + auto grid_dims_next_pos = getNextGdimsPos(); + if (grid_dims_next_pos >= 0) { + return gdimxAt(grid_dims_next_pos); + } else { + return -1; + } +} + +int PreferredLaunchConfig::peekNextGdimy() const { + auto grid_dims_next_pos = getNextGdimsPos(); + if (grid_dims_next_pos >= 0) { + return gdimyAt(grid_dims_next_pos); + } else { + return -1; + } +} + +int PreferredLaunchConfig::getNextGdimsPos() const { + auto grid_dims_next_pos = grid_dims_pos_ + 1; + if (grid_dims_next_pos < (int)valid_grid_dims_.size()) { + return grid_dims_next_pos; + } else { + return -1; + } +} + +namespace { + +// Estimated register count available for persistent buffer. The +// available space is considered to depend on the size of the +// persistent buffer itself due to the predicate caching +int64_t getAvailableRegisterCount(int64_t persistent_buffer_factor) { + // The thread block size is (currently) always 256, so each thread + // can use up to 255 registers + int64_t register_count = 255; + + // Offset a constant overhead + register_count -= 40; + + // Allow small number of spills + register_count += 5; + + // account for index caching, assuming each cache entry + // consumes one register + // TODO: Consider relaxing this reduction. It seems likes + // overestimation. 
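+  // For example, a persistent buffer factor of 10 (a hypothetical
+  // value; the factor is computed by the callers below) would leave
+  // an estimated 255 - 40 + 5 - 10 = 210 registers for the
+  // persistent buffer.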
+ register_count -= persistent_buffer_factor; + + return register_count; +} + +int64_t getMinPersistentBufferSize( + const int64_t total_reduction_numel, + const int64_t bdimy, + const int64_t gdimy) { + return ceilDiv(ceilDiv(total_reduction_numel, bdimy), gdimy); +} + +// Return true if a given combination of parameters is likely to +// result in no (or little) register spilling +bool checkIfWithinRegisterSpace( + int64_t total_reduction_numel, + int64_t persistent_buffer_size, + int64_t vectorize_factor, + int64_t bdimy, + int64_t gdimy) { + // The extent of the persistent buffer domain + auto pb_factor = + getMinPersistentBufferSize(total_reduction_numel, bdimy, gdimy); + + TORCH_INTERNAL_ASSERT(pb_factor > 0); + + const auto available_reg_count = getAvailableRegisterCount(pb_factor); + + auto per_thread_persistent_buffer_size = + ceilDiv(ceilDiv(persistent_buffer_size, bdimy), gdimy) * vectorize_factor; + + auto persistent_buffer_reg_count = + ceilDiv(per_thread_persistent_buffer_size, sizeof(int)); + + log("persistent_buffer_reg_count: ", + persistent_buffer_reg_count, + ", available_reg_count: ", + available_reg_count); + + return persistent_buffer_reg_count <= available_reg_count; +} + +// Calculate the factor of work of the last thread block in each of +// reductions. More specifically, use the number of serial +// iterations for the persistent buffer loop as a proxy of the +// amount of work. The rest of the blocks should execute the loop +// buffer_size times, whereas the last block only processes the +// remaining iterations. +double getLastBlockWorkRatio( + const int64_t total_reduction_numel, + const int64_t bdimy, + const int64_t persistent_buffer_size) { + auto last_block_pb = + total_reduction_numel % (persistent_buffer_size * bdimy) / bdimy; + return ((double)last_block_pb) / (double)persistent_buffer_size; +}; + +// In the current outer normalization scheduling, only the last thread +// block of each reduction group hits the fallback path of the +// unswitched loops, so it can be significantly slower than the +// rest. This is particularly problematic with grid persistence as all +// thread blocks need to synchronize, so the slowest block determines +// the performance. This could be to some extent mitigated by +// adjusting the buffer size such that the work assigned to the last +// block is relatively smaller than the work assigned to the +// rest. +// +// Here, given a valid launch config, we try to slightly adjust it so +// that the ratio of the last work becomes the smallest. We do this by +// increasing buffer sizes and in turn decreasing gdimy and picking the +// configuration that has the smallest work ratio. All of this is done +// with some bounds, e.g., the buffer size should still be within the +// register space, the decrease of gdimy should be less than 10%, +// etc. These threshold values are experimentally picked on A100 with +// the current benchmarks, but more tuning would likely lead to better +// performance. +// +// The function returns the adjusted gdimy and persistent buffer size +// as well as a bool indicating whether the work size is +// sufficiently reduced. Nullopt is returned if no adjustment is +// successfully done and the search should continue. 
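+//
+// As a hypothetical example: with total_reduction_numel = 16384 and
+// bdimy = 32, a persistent buffer size of 10 means every block except
+// the last one runs 10 serial iterations, while the last block runs
+// (16384 % (10 * 32)) / 32 = 2 of them, i.e., a work ratio of 0.2,
+// already below the 0.25 target used in this function.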
+std::optional> reduceWorkOfLastBlock( + const PreferredLaunchConfig& launch_cfg, + const int64_t total_reduction_numel, + const int64_t total_iteration_numel, + const int64_t persistent_buffer_size, + const int64_t vectorize_factor) { + const auto bdimy = launch_cfg.bdimy(); + + // Aim to reduce the work size of the last block to be smaller than + // some factor of the rest of the blocks. + const double target_last_block_work_ratio = 0.25; + + // Start with the current gdimy and buffer size. Gradually increase + // the buffer size and in turn decrease gdimy with the bounds set as + // below. + auto current_gdimy = launch_cfg.gdimy(); + auto current_buffer_size = + getMinPersistentBufferSize(total_reduction_numel, bdimy, current_gdimy); + + log("reduceWorkOfLastBlock: ", current_gdimy, ", ", current_buffer_size); + + // Threshold to stop decreasing gdimy + const auto min_gdimy = current_gdimy * 0.9; + + // Keep track of the best gdimy and buffer size configuration + auto optimal_size = current_buffer_size; + auto optimal_gdimy = current_gdimy; + double optimal_work_ratio = + getLastBlockWorkRatio(total_reduction_numel, bdimy, current_buffer_size); + + // Find the best gdimy and buffer size configuration by lowering + // gdimy. Stop if the minimum gdimy is hit or the register limit is + // reached. + while (current_gdimy >= min_gdimy && + checkIfWithinRegisterSpace( + total_reduction_numel, + persistent_buffer_size, + vectorize_factor, + bdimy, + current_gdimy)) { + auto ratio_of_last_block_work = getLastBlockWorkRatio( + total_reduction_numel, bdimy, current_buffer_size); + log("Ratio of last block work: ", + ratio_of_last_block_work, + ", persistent_buffer: ", + current_buffer_size, + ", gdimy: ", + current_gdimy); + + if (ratio_of_last_block_work < optimal_work_ratio) { + optimal_work_ratio = ratio_of_last_block_work; + optimal_size = current_buffer_size; + optimal_gdimy = current_gdimy; + } + + if (ratio_of_last_block_work < target_last_block_work_ratio) { + // Good enough config found; stop searching + break; + } + + // not good enough; increase persistent buffer + ++current_buffer_size; + // adjust gdimy (decreased as persitent_buffer is increased) + current_gdimy = + ceilDiv(ceilDiv(total_reduction_numel, bdimy), current_buffer_size); + + log("Next buffer size: ", + current_buffer_size, + ", Next gdimy: ", + current_gdimy); + } + + // Use the optimal ratio if it's within the threshold + if (optimal_work_ratio < target_last_block_work_ratio) { + log("Successfully reduced to ", optimal_work_ratio); + return std::make_tuple(optimal_gdimy, optimal_size, true); + } + + // Acceptable config not found. Continue searching a better config + // by moving to the next candidate. However, if the next candidate + // incurs a larger number of grid syncs, i.e., the serial factor of + // the iteration domain is larger, the additional overheaad would + // likely to outweight the benefit of potentially better block + // specialization, so pick the best among found so far. + auto next_gdimx = launch_cfg.peekNextGdimx(); + + // If the next gdimx is negative, that means there's no more config + // candidate or the next would decrease the bdimx, which could be a + // large perf degradation, so stop the search then. 
+ if (next_gdimx < 0) { + log("Stop as there's no more search space left for gdimx"); + return std::make_tuple(optimal_gdimy, optimal_size, false); + } + + if (next_gdimx > 0) { + auto remaining_iteration_factor = ceilDiv( + ceilDiv(total_iteration_numel, vectorize_factor), launch_cfg.bdimx()); + auto current_iterration_count = + ceilDiv(remaining_iteration_factor, launch_cfg.gdimx()); + auto next_iteration_count = ceilDiv(remaining_iteration_factor, next_gdimx); + log("Next iteration count: ", + next_iteration_count, + ", next gdimx: ", + next_gdimx, + ", current iteration: ", + current_iterration_count, + ", curreng gdimx: ", + launch_cfg.gdimx()); + if (next_iteration_count > current_iterration_count) { + log("Still not good but stop here to avoid increase of iteration count"); + return std::make_tuple(optimal_gdimy, optimal_size, false); + } + } + + log("Acceptable config not found. Continue search"); + return std::nullopt; +} + +} // namespace + +// Iterate configurations from largest blockDim.x and smallest +// gridDim.y until the per-thread size of the persistent buffer +// becomes sufficiently small enough not to cause (significant) +// register spill. +std::optional getGridOuterNormalizationParams( + int64_t total_reduction_numel, + int64_t total_iteration_numel, + int64_t vectorize_factor, + int64_t persistent_buffer_size) { + PreferredLaunchConfig launch_cfg; + + // The launch config starts with the largest blockDim.x, which may + // be larger than the iteration size. Decrease it until it doesn't + // exceed the iteration size. + const auto max_bdimx = ceilDiv(total_iteration_numel, vectorize_factor); + while (launch_cfg.bdimx() > max_bdimx) { + if (!launch_cfg.moveToNextBdim()) { + // The iteration size is too small. It might still be worthwhile + // to be persistent, but it's unlikely to be performant anyway + return std::nullopt; + } + } + + // Iterate candidates of launch configurations + while (!launch_cfg.isInvalid()) { + log("Current config: ", launch_cfg); + + // Skip if iterations are not evenly distributed among thread + // blocks unless the remaining factor is smaller than + // gridDim.x. However, don't skip if this is the last valid config + // within the same blockDim config. + auto remaining_gdimx_factor = + ceilDiv(total_iteration_numel / vectorize_factor, launch_cfg.bdimx()); + // TODO: Needs better tuning. Probably want to allow + // configurations that are slightly uneven + if (remaining_gdimx_factor > launch_cfg.gdimx() && + remaining_gdimx_factor % launch_cfg.gdimx() != 0 && + !launch_cfg.isNextSmallerBdimx()) { + log("Rejected due to uneven iteration domain"); + launch_cfg.moveToNextConfig(); + continue; + } + + if (!checkIfWithinRegisterSpace( + total_reduction_numel, + persistent_buffer_size, + vectorize_factor, + launch_cfg.bdimy(), + launch_cfg.gdimy())) { + log("Rejected due to register spill"); + launch_cfg.moveToNextConfig(); + continue; + } + + // At this point, gdimy is large enough to keep the register + // pressure low enough. + + // In case the iteration domain is small, the gdimx and bdimx pair + // may be too large and some threads/blocks may be idle. 
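+    // For instance (hypothetical sizes): with total_iteration_numel =
+    // 1024, vectorize_factor = 4 and bdimx = 16, remaining_gdimx_factor
+    // is ceilDiv(1024 / 4, 16) = 16, so any gdimx larger than 16 would
+    // leave whole thread blocks idle and the config is rejected below.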
+ + if (remaining_gdimx_factor < launch_cfg.gdimx()) { + log("gdimx too large: ", + remaining_gdimx_factor, + ", vec: ", + vectorize_factor); + launch_cfg.moveToNextConfig(); + continue; + } + + // If there's idle tidx threads, don't accept if there's further + // config candidates with smaller bdimx + if (vectorize_factor * launch_cfg.bdimx() * launch_cfg.gdimx() > + total_iteration_numel && + launch_cfg.canLowerBdimx()) { + log("Skip due to too large bdimx: ", launch_cfg.bdimx()); + launch_cfg.moveToNextBdim(); + continue; + } + + // Adjust gdimy and buffer size for processing predicates more + // efficiently through the block specialization, so that the last + // block is assigned with a relatively small chunk of work. + // For some reason, this doesn't work well on Titan RTX. It seems + // it's just better unswitching by a small factor. + // TODO: Test other generations of GPUs + int64_t adjusted_gdimy = -1; + int64_t adjusted_buffer_size = -1; + bool last_block_work_reduced = false; + const auto major_ver = at::cuda::getCurrentDeviceProperties()->major; + const auto minor_ver = at::cuda::getCurrentDeviceProperties()->minor; + if (major_ver == 7 && minor_ver == 5) { + adjusted_gdimy = launch_cfg.gdimy(); + adjusted_buffer_size = getMinPersistentBufferSize( + total_reduction_numel, launch_cfg.bdimy(), launch_cfg.gdimy()); + last_block_work_reduced = false; + } else { + auto gdimy_pb_size = reduceWorkOfLastBlock( + launch_cfg, + total_reduction_numel, + total_iteration_numel, + persistent_buffer_size, + vectorize_factor); + if (!gdimy_pb_size.has_value()) { + launch_cfg.moveToNextConfig(); + continue; + } + std::tie(adjusted_gdimy, adjusted_buffer_size, last_block_work_reduced) = + *gdimy_pb_size; + } + + // Acceptable configuration found + auto launch_params = LaunchParams( + launch_cfg.gdimx(), + adjusted_gdimy, + LaunchParams::UNINITIALIZED_VAL, + launch_cfg.bdimx(), + launch_cfg.bdimy(), + LaunchParams::UNINITIALIZED_VAL); + + // If the last block is sufficiently reduced, unswitch the whole + // persistent buffer. Otherwise, unswitch by a factor of 4. + int64_t unswitch_factor = last_block_work_reduced + ? adjusted_buffer_size + : std::min(4l, adjusted_buffer_size); + + GridOuterNormalizationParams params = { + .launch_params = launch_params, + .persistent_buffer_factor = adjusted_buffer_size, + .unswitch_factor = unswitch_factor}; + return params; + } + + // No valid config found. Return launch_cfg, which should be marked + // as invalid + TORCH_INTERNAL_ASSERT(launch_cfg.isInvalid()); + return std::nullopt; +} + +} // namespace normalization_scheduler_utils +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.h b/third_party/nvfuser/csrc/scheduler/normalization_utils.h new file mode 100644 index 0000000000000..6e39eb6a8c270 --- /dev/null +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.h @@ -0,0 +1,155 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace normalization_scheduler_utils { + +//! Utility class to iterate candidates of launch configurations in a +//! preferred order. The iteration order is defined as: +//! +//! for bdimx in all valid bdimx in an decreasing order +//! for gdimy in valid gdimy values in an increasing order +//! +//! Each of bdimx and gdimy determines bdimy and gdimx, respecitively, +//! 
such that the number of threads per block is always 256 and the +//! number of blocks is always equal to the number of SMs. +class PreferredLaunchConfig { + public: + //! Minimum blockDim.x. + static constexpr int kMinBdimx = 8; + //! Maximum blockDim.x. + static constexpr int kMaxBdimx = 16; + + PreferredLaunchConfig(); + + int bdimx() const { + return bdimx_; + } + + int bdimy() const { + return bdimy_; + } + + int gdimx() const { + return gdimxAt(grid_dims_pos_); + } + + int gdimy() const { + return gdimyAt(grid_dims_pos_); + } + + //! Peek the next gdimx. -1 is returned if no further gdimx is available. + int peekNextGdimx() const; + + //! Peek the next gdimy. -1 is returned if no further gdimy is available. + int peekNextGdimy() const; + + //! Move to the next launch configuration. Will be marked as invalid + //! if no valid configuration exists. Return true if successfully moved. + bool moveToNextConfig(); + + //! Try setting blockDim to the next valid config if + //! available. Return false if no valid config exists. gridDim is + //! reset. + bool moveToNextBdim(); + + //! Query if the next configuration will cause blockDim.x to become + //! smaller. + bool isNextSmallerBdimx() const; + + //! Query if blockDim.x can be further lowered + bool canLowerBdimx() const; + + //! Query if no valid configuration is found + bool isInvalid() const { + return !valid_; + } + + private: + //! Populate the list of valid gridDim configurations + void initValidGdims(); + + int gdimxAt(int pos) const { + return valid_grid_dims_.at(pos).first; + } + + int gdimyAt(int pos) const { + return valid_grid_dims_.at(pos).second; + } + + //! Set blockDim.x and in turn blockDim.y. Return true if the + //! specified blockDim.x is successfully set. If dry_run is true, + //! just check if the given config is valid but do not modify the + //! current config. + bool setBdimx(int bdimx, bool dry_run = false); + + void resetGdim() { + grid_dims_pos_ = 0; + } + + void resetBdim() { + // Start with the maximum bdimx and lower it until satisfactory + // config is found + setBdimx(kMaxBdimx); + } + + //! Try setting gridDim to the next valid config if + //! available. Return false if no valid config exists + bool moveToNextGdim(); + + int getNextGdimsPos() const; + + void invalidate() { + valid_ = false; + } + + friend std::ostream& operator<<(std::ostream& os, PreferredLaunchConfig cfg) { + os << "{gdimx: " << cfg.gdimx() << ", gdimy: " << cfg.gdimy() + << ", bdimx: " << cfg.bdimx() << ", bdimy: " << cfg.bdimy() << "}"; + return os; + } + + private: + //! Remember if it is still a valid configuration + bool valid_ = false; + + //! List of valid gridDims ordered by the dimension of + //! gridDim.x. Larger gridDim.x is preferred as it would promote + //! larger independent parallelism + std::vector> valid_grid_dims_; + //! The offset of the Current gridDim in valid_grid_dims_ + int grid_dims_pos_ = 0; + + //! Current blockDim.x + int bdimx_ = 0; + //! Current blockDim.y + int bdimy_ = 0; +}; + +//! 
Scheduling parameters for grid outer normalization +struct GridOuterNormalizationParams { + LaunchParams launch_params; + int64_t persistent_buffer_factor = -1; + int64_t unswitch_factor = -1; +}; + +std::optional getGridOuterNormalizationParams( + int64_t total_reduction_numel, + int64_t total_iteration_numel, + int64_t vectorize_factor, + int64_t persistent_buffer_size); + +} // namespace normalization_scheduler_utils +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/third_party/nvfuser/csrc/scheduler/reduction.cpp b/third_party/nvfuser/csrc/scheduler/reduction.cpp index 67f35018af762..883630c2d1175 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction.cpp +++ b/third_party/nvfuser/csrc/scheduler/reduction.cpp @@ -435,14 +435,14 @@ std::shared_ptr innerReductionHeuristic( rparams->grid_dim_iter_dom = ParallelType::BIDy; if (godim > scheduler_utils::y_grid_limit) { - rparams->split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom_outer = true; gdimy = std::min(godim, scheduler_utils::y_grid_limit); } } else { rparams->grid_dim_iter_dom = ParallelType::BIDx; if (gdimx > scheduler_utils::x_grid_limit) { - rparams->split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom_outer = true; gdimx = godim; } } @@ -798,7 +798,7 @@ std::shared_ptr outerReductionHeuristic( flip_grid ? ParallelType::BIDy : ParallelType::BIDx; if (gidim > (flip_grid ? scheduler_utils::y_grid_limit : scheduler_utils::x_grid_limit)) { - rparams->split_grid_dim_iter_dom = true; + rparams->split_grid_dim_iter_dom_outer = true; if (flip_grid) { gdimy = scheduler_utils::y_grid_limit; } else { diff --git a/third_party/nvfuser/csrc/scheduler/reduction_heuristic.h b/third_party/nvfuser/csrc/scheduler/reduction_heuristic.h index 5b9b3d313e35c..383c1462ec1a5 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction_heuristic.h +++ b/third_party/nvfuser/csrc/scheduler/reduction_heuristic.h @@ -66,8 +66,10 @@ class ReductionParams : public HeuristicParams { int64_t unroll_factor_iter_dom = 1; // vectorize instead of unroll bool vectorize_iter_dom = false; - // Split grid dim for iteration axis in case it's too large for cuda - bool split_grid_dim_iter_dom = false; + // Inner split grid dim for iteration axis in case it's too large for cuda + bool split_grid_dim_iter_dom_inner = false; + // Outer split grid dim for iteration axis in case it's too large for cuda + bool split_grid_dim_iter_dom_outer = false; // Which block parallel dimension should be used for the iter domain. // !!WARNING!! Convenience method, this be unique based on non-parallel type @@ -100,6 +102,12 @@ class ReductionParams : public HeuristicParams { // parameters, not used for equivalence/hashing. 
ParallelType grid_dim_outer_reduction = ParallelType::Serial; + // Use computeWith to persistent buffers + bool compute_persistent_buffer_with_first_consumer = false; + + bool static_bdimx = false; + bool static_bdimy = false; + bool isUnrolled() const { return unroll_factor_inner_reduction > 1 || unroll_factor_iter_dom > 1 || unroll_factor_outer_reduction > 1; @@ -133,14 +141,24 @@ class ReductionParams : public HeuristicParams { other.multiple_reds_per_blk == multiple_reds_per_blk && other.unroll_factor_iter_dom == unroll_factor_iter_dom && other.vectorize_iter_dom == vectorize_iter_dom && - other.split_grid_dim_iter_dom == split_grid_dim_iter_dom && + other.split_grid_dim_iter_dom_inner == split_grid_dim_iter_dom_inner && + other.split_grid_dim_iter_dom_outer == split_grid_dim_iter_dom_outer && other.cross_block_outer_reduction == cross_block_outer_reduction && other.cross_grid_outer_reduction == cross_grid_outer_reduction && other.unroll_factor_outer_reduction == unroll_factor_outer_reduction && other.split_grid_dim_outer_reduction == split_grid_dim_outer_reduction && other.batches_per_block_outer_reduction == - batches_per_block_outer_reduction; + batches_per_block_outer_reduction && + other.compute_persistent_buffer_with_first_consumer == + compute_persistent_buffer_with_first_consumer; + + if (other.static_bdimy || static_bdimy) { + attr_equal = attr_equal && other.lparams.bdimy() == lparams.bdimy(); + } + if (other.static_bdimx || static_bdimx) { + attr_equal = attr_equal && other.lparams.bdimx() == lparams.bdimx(); + } return attr_equal; } @@ -179,8 +197,12 @@ class ReductionParams : public HeuristicParams { ss << "\nIteration Domain: "; if (grid_dim_iter_dom != ParallelType::Serial) { - ss << grid_dim_iter_dom << " / " - << (split_grid_dim_iter_dom ? 
"split grid dimension / " : ""); + ss << grid_dim_iter_dom << " / "; + if (split_grid_dim_iter_dom_outer) { + ss << "split grid dimension outer / "; + } else if (split_grid_dim_iter_dom_inner) { + ss << "split grid dimension inner / "; + } } if (block_dim_iter_dom != ParallelType::Serial) { ss << block_dim_iter_dom << " / "; @@ -217,6 +239,10 @@ class ReductionParams : public HeuristicParams { ss << "factor " << unroll_factor_inner_reduction; } + if (compute_persistent_buffer_with_first_consumer) { + ss << "\ncomputeWith persistent buffers"; + } + ss << "\n" << lparams.toString() << "\n"; ss << "====================================\n"; return ss.str(); @@ -240,12 +266,15 @@ class ReductionParams : public HeuristicParams { static_cast(multiple_reds_per_blk) << (bits - 13) ^ static_cast(unroll_factor_iter_dom) << (bits - 14) ^ static_cast(vectorize_iter_dom) << (bits - 15) ^ - static_cast(split_grid_dim_iter_dom) << (bits - 16) ^ - static_cast(cross_block_outer_reduction) << (bits - 17) ^ - static_cast(cross_grid_outer_reduction) << (bits - 18) ^ - static_cast(split_grid_dim_outer_reduction) << (bits - 19) ^ - static_cast(batches_per_block_outer_reduction) << (bits - 20) ^ - static_cast(unroll_factor_outer_reduction) << (bits - 21); + static_cast(split_grid_dim_iter_dom_outer) << (bits - 16) ^ + static_cast(split_grid_dim_iter_dom_inner) << (bits - 17) ^ + static_cast(cross_block_outer_reduction) << (bits - 18) ^ + static_cast(cross_grid_outer_reduction) << (bits - 19) ^ + static_cast(split_grid_dim_outer_reduction) << (bits - 20) ^ + static_cast(batches_per_block_outer_reduction) << (bits - 21) ^ + static_cast(unroll_factor_outer_reduction) << (bits - 22) ^ + static_cast(compute_persistent_buffer_with_first_consumer) + << (bits - 23); return attr_hash; } diff --git a/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp b/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp index 9b3c70db23fd8..c852e58a6a937 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp @@ -30,6 +30,9 @@ TensorView* scheduleReductionTV( const int outer_reduce_axis = rparams.schedule_3D ? 1 : 0; const int inner_reduce_axis = rparams.schedule_3D ? 2 : has_iter_axis ? 
1 : 0; + const bool is_outer_grid_persistence = rparams.persistent_kernel && + rparams.cross_grid_inner_reduction && !rparams.fastest_dim; + TORCH_INTERNAL_ASSERT( (int)reduction_tv->nDims() > std::max(iter_axis, std::max(outer_reduce_axis, inner_reduce_axis)), @@ -64,6 +67,12 @@ TensorView* scheduleReductionTV( reduction_tv->axis(axis + 1)->parallelize(ptype); }; + auto inner_parallel_static = [&reduction_tv]( + int axis, ParallelType ptype, int factor) { + reduction_tv->split(axis, factor); + reduction_tv->axis(axis + 1)->parallelize(ptype); + }; + auto inner_unswitch = [&reduction_tv](int axis) { reduction_tv->split(axis, 1); reduction_tv->axis(axis + 1)->parallelize(ParallelType::Unswitch); @@ -89,7 +98,31 @@ TensorView* scheduleReductionTV( reduction_tv->axis(axis)->parallelize(ParallelType::Unroll); }; - if (rparams.persistent_kernel) { + if (is_outer_grid_persistence) { + const auto reduction_axis = inner_reduce_axis; + TORCH_INTERNAL_ASSERT(rparams.static_bdimy, "blockDim.y must be static"); + inner_parallel_static( + reduction_axis, + rparams.block_dim_inner_reduction, + rparams.lparams.bdimy()); + reduction_tv->split( + reduction_axis, rparams.batches_per_block_inner_reduction); + reduction_tv->axis(reduction_axis) + ->parallelize(rparams.grid_dim_inner_reduction); + // Unswitch the persistent buffer by a factor of + // unroll_factor_inner_reduction. If that is equal to the + // persistent buffer size, unswitch the whole buffer by + // outer-unswith by 1. Otherwise, split the persistent buffer by + // the unsiwtch factor and just unswitch the inner domain + if (rparams.batches_per_block_inner_reduction == + rparams.unroll_factor_inner_reduction) { + outer_unswitch(reduction_axis + 1); + } else { + reduction_tv->split( + reduction_axis + 1, rparams.unroll_factor_inner_reduction); + outer_unswitch(reduction_axis + 2); + } + } else if (rparams.persistent_kernel) { // Persistent Format: // [Grid Split, persistent buffer, unswitch, unroll, thread dim, vectorize] if (rparams.vectorize_inner_reduction) { @@ -115,7 +148,6 @@ TensorView* scheduleReductionTV( if (rparams.pad_inner_reduction_to_warp) { reduction_tv->axis(outer_i)->padToMultipleOfWarp(); } - } else { // Non-persistent format: // [Grid Split, Remainder, unswitch, unroll, thread dim, vectorize] @@ -191,27 +223,59 @@ TensorView* scheduleReductionTV( } if (isParallelTypeThread(rparams.block_dim_iter_dom)) { - inner_parallel(iter_axis, rparams.block_dim_iter_dom); + if (is_outer_grid_persistence) { + TORCH_INTERNAL_ASSERT( + rparams.static_bdimx, "blockDim.x must be static"); + inner_parallel_static( + iter_axis, rparams.block_dim_iter_dom, rparams.lparams.bdimx()); + } else { + inner_parallel(iter_axis, rparams.block_dim_iter_dom); + } } if (!rparams.vectorize_iter_dom && rparams.unroll_factor_iter_dom > 1) { inner_unroll(iter_axis, rparams.unroll_factor_iter_dom); } - if (rparams.unroll_factor_iter_dom > 1) { + // Do not unswitch interation domain in the case of outer grid + // persistence as it's unclear if it's beneficial. 
+ if (rparams.unroll_factor_iter_dom > 1 && !is_outer_grid_persistence) { inner_unswitch(iter_axis); } if (isParallelTypeThread(rparams.grid_dim_iter_dom)) { - if (rparams.split_grid_dim_iter_dom) { + if (rparams.split_grid_dim_iter_dom_outer) { outer_parallel(iter_axis, rparams.grid_dim_iter_dom); + } else if (rparams.split_grid_dim_iter_dom_inner) { + inner_parallel(iter_axis, rparams.grid_dim_iter_dom); } else { reduction_tv->axis(iter_axis)->parallelize(rparams.grid_dim_iter_dom); } } } - return sortAndRFactor(reduction_tv); + auto reduction_rf_tv = sortAndRFactor(reduction_tv); + + // In the case of outer grid persistence, make sure the vectorized + // domain placed at the innermost position. + // TODO: Why isn't this the case by default? + if (is_outer_grid_persistence) { + int vec_id_cur_pos = -1; + std::unordered_map vec_reorder_map; + for (const auto i : c10::irange(reduction_rf_tv->nDims())) { + auto id = reduction_rf_tv->axis(i); + if (id->getParallelType() == ParallelType::Vectorize) { + vec_id_cur_pos = i; + vec_reorder_map[i] = -1; + } else if (vec_id_cur_pos >= 0) { + vec_reorder_map[i] = i - 1; + } + } + TORCH_INTERNAL_ASSERT(vec_id_cur_pos != -1, "Vectorized ID not found"); + reduction_rf_tv->reorder(vec_reorder_map); + } + + return reduction_rf_tv; } namespace { @@ -239,6 +303,44 @@ std::vector addBackBroadcasts( return axes; } +// Check if a reduction is effectively an allreduce. +bool isGridAllreduce(TensorView* reduction_tv) { + // Only Local tensor is converted to allreduce + if (reduction_tv->getMemoryType() != MemoryType::Local) { + return false; + } + + // Collect all reduction parallel types + ParallelTypeBitmap reduction_parallel_types; + std::for_each( + reduction_tv->domain()->domain().begin(), + reduction_tv->domain()->domain().end(), + [&](auto id) { + if (id->isReduction() && + isParallelTypeBlockDim(id->getParallelType())) { + reduction_parallel_types.set(id->getParallelType()); + } + }); + + // If any of the reduction parallel types is used to parallelize + // the broadcast, it will be converted to an allreduce reduction expr + for (auto bcast_expr : + ir_utils::filterByType(reduction_tv->uses())) { + auto bcast_tv = bcast_expr->out()->as(); + if (std::any_of( + bcast_tv->domain()->domain().begin(), + bcast_tv->domain()->domain().end(), + [&](auto bcast_id) { + auto pt = bcast_id->getParallelType(); + return isParallelTypeBlockDim(pt) && + reduction_parallel_types.get(pt); + })) { + return true; + } + } + return false; +} + } // namespace void multiReductionInliner( @@ -250,6 +352,9 @@ void multiReductionInliner( std::vector cached_inputs, std::vector> cached_outputs, std::vector dummy_outputs) { + const bool is_outer_grid_persistence = rparams.persistent_kernel && + rparams.cross_grid_inner_reduction && !rparams.fastest_dim; + // Propagate transformations before we rfactor the other reductions TransformPropagator propagator(reference_tv); MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator); @@ -276,7 +381,8 @@ void multiReductionInliner( } for (auto reduction_tv_ : reduction_tvs) { - if (reduction_tv_ == reduction_tv) { + if (reduction_tv_ == reduction_tv || + reduction_tv_->definition()->isA()) { // This should come in already rfactored continue; } else { @@ -359,18 +465,44 @@ void multiReductionInliner( std::vector rfactor_and_reduction_tvs = { reference_tv, reduction_tv}; // If reference shouldn't be unrolled, clear that parallel type. 
+ // In the case of outer grid persistence, replace Vector with Group for (auto tv : rfactor_and_reduction_tvs) { if (are_unrolled.count(tv) == 0) { for (const auto i : c10::irange(tv->nDims())) { auto id = tv->axis((int)i); - if (id->getParallelType() == ParallelType::Unroll || + // Use Group only for grid reductions (i.e., not for rfactor'ed + // reductions) + if (is_outer_grid_persistence && + std::find(reduction_tvs.begin(), reduction_tvs.end(), tv) != + reduction_tvs.end() && + id->getParallelType() == ParallelType::Vectorize) { + tv->axis((int)i)->parallelize(ParallelType::Group); + for (auto sibling : ir_utils::siblingTvsOf(tv)) { + sibling->axis((int)i)->parallelize(ParallelType::Group); + } + } else if ( + id->getParallelType() == ParallelType::Unroll || id->getParallelType() == ParallelType::Vectorize || id->getParallelType() == ParallelType::MisalignedVectorize) { tv->axis((int)i)->parallelize(ParallelType::Serial); + for (auto sibling : ir_utils::siblingTvsOf(tv)) { + sibling->axis((int)i)->parallelize(ParallelType::Serial); + } } } } } + + std::vector allreduce_tvs; + std::copy_if( + reduction_tvs.begin(), + reduction_tvs.end(), + std::back_inserter(allreduce_tvs), + [&](auto tv) { return reduction_tv != tv && isGridAllreduce(tv); }); + if (!allreduce_tvs.empty()) { + scheduler_utils::parallelizeAllLike( + reduction_tv, -1, allreduce_tvs, {ParallelType::Group}); + } } // Remove dummy outputs as they can inadvertently affect CA positions @@ -538,7 +670,8 @@ std::vector projectPersistentBuffers(Fusion* fusion) { const auto& projected_buffers = persistent_info.projectable_persistent_buffers; - TORCH_INTERNAL_ASSERT(persistent_buffers.size() == persistent_buffers.size()); + TORCH_INTERNAL_ASSERT( + persistent_buffers.size() == persistent_resolution_points.size()); // Iterate through projected buffers, tracking which index it corresponds too // since there's a resolution point entry for every buffer. 
diff --git a/third_party/nvfuser/csrc/scheduler/registry.cpp b/third_party/nvfuser/csrc/scheduler/registry.cpp index 9bd6706ae42e8..95033a0c12f42 100644 --- a/third_party/nvfuser/csrc/scheduler/registry.cpp +++ b/third_party/nvfuser/csrc/scheduler/registry.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -1737,6 +1738,13 @@ class PersistentKernelScheduler : public SchedulerEntry { }); auto& reduction_tvs = reduction_tv_entry.get(); + auto properties = + scheduler_utils::getProperties(fusion, runtime_info, reduction_tvs[0]); + + if (!properties.fastest_dim_reduction) { + return canScheduleRunTimeOuter( + fusion, runtime_info, data_cache, reduction_tvs, properties); + } auto persistent_buffer_info_entry = HeuristicSummaryEntry( @@ -1761,18 +1769,13 @@ class PersistentKernelScheduler : public SchedulerEntry { const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - // If there's a small iteration dimension but a large reduction dimension it - // may not make sense to make a persistent kernel - auto properties = - scheduler_utils::getProperties(fusion, runtime_info, reduction_tvs[0]); - // TODO: Enable grid persistence const auto available_persistent_buffer_size = scheduler_utils::register_file_size; if (persistent_buffer_size > available_persistent_buffer_size) { scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Persistent, "no enough registers for persistece"); + ScheduleHeuristic::Persistent, "not enough registers for persistece"); return false; } @@ -1839,6 +1842,203 @@ class PersistentKernelScheduler : public SchedulerEntry { params_ = getPersistentHeuristics(fusion, runtime_info, data_cache); TORCH_INTERNAL_ASSERT(params_ != nullptr); } + + static bool canScheduleRunTimeOuter( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache, + const std::vector& reduction_tvs, + const scheduler_utils::TvProperties& properties) { + FUSER_PERF_SCOPE("PersistentKernelScheduler::canScheduleRuntimeOuter"); + FusionGuard fg(fusion); + + const auto device_prop = at::cuda::getCurrentDeviceProperties(); + + const int64_t sm_register_file_size = + static_cast(device_prop->regsPerBlock * sizeof(int)); + + auto persistent_buffer_info_entry = + HeuristicSummaryEntry( + data_cache, [&fusion]() { + return std::make_unique( + scheduler_utils::persistentBuffers(fusion)); + }); + + const auto& persistent_buffer_info = persistent_buffer_info_entry.get(); + + auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( + fusion, runtime_info, persistent_buffer_info, data_cache); + + // Note that projected buffer size can be zero + auto persistent_buffer_size = + persistent_buffer_size_info.projected_persistent_buffer_size == 0 + ? 
persistent_buffer_size_info.persistent_buffer_size + : std::min( + persistent_buffer_size_info.persistent_buffer_size, + persistent_buffer_size_info.projected_persistent_buffer_size); + + const int64_t device_multiprocessor_count = + (int64_t)device_prop->multiProcessorCount; + + const auto available_persistent_buffer_size = + sm_register_file_size * device_multiprocessor_count; + + if (persistent_buffer_size > available_persistent_buffer_size) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, + "not enough registers for persistence"); + return false; + } + + const int64_t vectorization_factor = + vectorize_helper::getVectorizationFactor( + runtime_info, + reduction_tvs.at(0), + data_cache, + reduction_tvs.at(0)->nDims() - + properties.inner_most_dimension_ndims); + + // Minimum required multi reduction factor. + const int64_t min_multi_reduction_factor = vectorization_factor * + normalization_scheduler_utils::PreferredLaunchConfig::kMinBdimx; + + const int64_t required_sm_per_norm = ceilDiv( + persistent_buffer_size * min_multi_reduction_factor, + sm_register_file_size); + + // If the persistence requires over half the device don't do grid + // persistence as we can't overlap the grid comms. + if (required_sm_per_norm > + scheduler_utils::safeDiv(device_multiprocessor_count, 2)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, + "requires over half GPU persistence.", + " required SMs per normalization: ", + required_sm_per_norm); + return false; + } + + const bool is_cross_grid = required_sm_per_norm > 1; + + std::optional + cross_grid_params; + + if (is_cross_grid) { + // Don't try to be persistent unless at least 4-way vectorized + // as register usage is hard to control + // TODO: Is this necessary for block persistence as well? + if (vectorization_factor < 4) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "not enough vectorized"); + return false; + } + + // Make sure there's a valid grid persistence launch config + cross_grid_params = + normalization_scheduler_utils::getGridOuterNormalizationParams( + properties.total_reduction_numel, + properties.total_iteration_numel, + vectorization_factor, + persistent_buffer_size); + + if (!cross_grid_params.has_value()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "no valid launch config found"); + return false; + } + } + + // Maximum number of iteration dimensions we can have and still be + // persistent. + const int64_t max_multi_reduction_factor = scheduler_utils::safeDiv( + is_cross_grid ? available_persistent_buffer_size + : sm_register_file_size, + persistent_buffer_size); + + // Don't go persistent if we can't fit the minimum multi reduction + // factor + if (max_multi_reduction_factor < min_multi_reduction_factor) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, + "Not enough threads.", + " Multi reduction factor, ", + max_multi_reduction_factor, + ", is smaller than minimum multi reduction factor, ", + min_multi_reduction_factor); + return false; + } + + const int64_t max_used_sms = is_cross_grid + ? ceilDiv( + ceilDiv(properties.total_iteration_numel, vectorization_factor), + cross_grid_params->launch_params.bdimx()) * + cross_grid_params->launch_params.gdimy() + : ceilDiv( + properties.total_iteration_numel * persistent_buffer_size, + sm_register_file_size); + + // Bandwidth suffers if the number of used SMs is small. 
This is + // particularly impactful in the case of cross grid, so at least + // half of the SMs are required to be used. In the case of cross + // block, keep using the existing heuristics for now. + if (is_cross_grid && + max_used_sms < + scheduler_utils::safeDiv(device_multiprocessor_count, 2)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, + "cross grid - not enough used SMs: ", + max_used_sms); + return false; + } + + const int64_t device_max_threads_per_multiprocessor = + (int64_t)device_prop->maxThreadsPerMultiProcessor; + const int64_t min_fraction_of_sms = + scheduler_utils::safeDiv(device_multiprocessor_count, 8); + if (properties.total_reduction_numel >= + device_max_threads_per_multiprocessor * 4 && // Large reduction dim + max_used_sms < min_fraction_of_sms) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "not enough used SMs"); + return false; + } + + // The runtime kernel for grouped normal grid reductions is not + // well tuned, and it turned out to be quite difficult to get + // consistently better performances than non-persistent + // schedules. Disabled for now. + // TODO: Enable non-welford persistent reductions + if (is_cross_grid && + std::any_of( + reduction_tvs.begin(), + reduction_tvs.end(), + [](TensorView* reduction_tv) { + return !reduction_tv->definition()->isA(); + })) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "non-Welford not enabled yet"); + return false; + } + + // Had a hard time tuning on Titan RTX and V100 when the iteration + // space is not evenly divided by threads and thread blocks. It + // doesn't seem to be noticeably bad on A100, though. For now, + // disable the schedule if not evenly divisible on Titan RTX and + // V100, i.e., compute architecture version 7. 
+ // TODO: Revisit + if (is_cross_grid && + (properties.total_iteration_numel % + (vectorization_factor * cross_grid_params->launch_params.bdimx() * + cross_grid_params->launch_params.gdimx()) != + 0) && + device_prop->major == 7) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "iteration not evenly divided"); + return false; + } + + return true; + } }; // Schedule Table diff --git a/third_party/nvfuser/csrc/scheduler/utils.cpp b/third_party/nvfuser/csrc/scheduler/utils.cpp index a111722978f66..9c3fa22c9cb74 100644 --- a/third_party/nvfuser/csrc/scheduler/utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/utils.cpp @@ -624,7 +624,7 @@ namespace { std::unique_ptr getScopePersistenceFactors( Fusion* fusion, - PersistentBufferInfo& persistent_buffer_info) { + const PersistentBufferInfo& persistent_buffer_info) { auto new_persistent_factor_map_ptr = std::make_unique(); auto& new_persistent_factor_map = *new_persistent_factor_map_ptr; @@ -744,7 +744,7 @@ getScopePersistenceFactors( PersistentBufferSizeReturn persistentBufferSize( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, - PersistentBufferInfo& persistent_buffer_info, + const PersistentBufferInfo& persistent_buffer_info, HeuristicSummary* data_cache) { FUSER_PERF_SCOPE("scheduler_utils::persistentBufferSize"); diff --git a/third_party/nvfuser/csrc/scheduler/utils.h b/third_party/nvfuser/csrc/scheduler/utils.h index 6177c3af16dd2..854bdc016d5bb 100644 --- a/third_party/nvfuser/csrc/scheduler/utils.h +++ b/third_party/nvfuser/csrc/scheduler/utils.h @@ -188,7 +188,7 @@ struct PersistentBufferSizeReturn { TORCH_CUDA_CU_API PersistentBufferSizeReturn persistentBufferSize( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, - PersistentBufferInfo& persistent_buffers, + const PersistentBufferInfo& persistent_buffers, HeuristicSummary* data_cache = nullptr); // Merges tensor view to the form: diff --git a/third_party/nvfuser/csrc/utils.cpp b/third_party/nvfuser/csrc/utils.cpp index 2ac0c92f5eb9b..29e8c4e29e08e 100644 --- a/third_party/nvfuser/csrc/utils.cpp +++ b/third_party/nvfuser/csrc/utils.cpp @@ -117,6 +117,7 @@ auto parseDebugDumpOptions() { {"ptxas_verbose", DebugDumpOption::PrintPtxasLog}, {"buffer_reuse_verbose", DebugDumpOption::BufferReuseInfo}, {"scheduler_params", DebugDumpOption::SchedulerDebug}, + {"scheduler_verbose", DebugDumpOption::SchedulerVerbose}, {"parallel_dimensions", DebugDumpOption::ParallelDimensions}, {"halo", DebugDumpOption::Halo}, {"perf_debug_verbose", DebugDumpOption::PerfDebugVerbose}, diff --git a/third_party/nvfuser/csrc/utils.h b/third_party/nvfuser/csrc/utils.h index 0b44516e23254..e37ce57d244c7 100644 --- a/third_party/nvfuser/csrc/utils.h +++ b/third_party/nvfuser/csrc/utils.h @@ -55,6 +55,7 @@ enum class DebugDumpOption { PrintPtxasLog, //!< Print the ptxas verbose log including register usage BufferReuseInfo, //!< Dump the analysis details of local/shared buffer re-use SchedulerDebug, //! Dump scheduler heuristic parameters + SchedulerVerbose, //! Dump detailed scheduler logging ParallelDimensions, //!< Dump known parallel dimensions Halo, //! Halo information of tensors PerfDebugVerbose, //! 
When running kernels, print verbose information diff --git a/third_party/nvfuser/test/test_gpu_outer_reduction.cpp b/third_party/nvfuser/test/test_gpu_outer_reduction.cpp index 2d569b2e85ff0..2f62d795b25ac 100644 --- a/third_party/nvfuser/test/test_gpu_outer_reduction.cpp +++ b/third_party/nvfuser/test/test_gpu_outer_reduction.cpp @@ -1343,5 +1343,866 @@ TEST_F( grid_persistent_batchnorm_bwd_manual(256, 28, 512, DataType::Float); } +//////////////////////////////////////////////////////////////// +/// Scheduler tests +//////////////////////////////////////////////////////////////// + +namespace { + +TensorView* cast(TensorView* tv, DataType dtype) { + if (tv->getDataType() != dtype) { + return castOp(dtype, tv); + } else { + return tv; + } +} + +bool shouldBePersistent( + int64_t N, + int64_t HW, + DataType dtype, + bool is_bwd, + bool use_weights = false, + DataType weights_dtype = DataType::Float) { + // Non-welford is disabled for now + if (is_bwd) { + return false; + } + + const int64_t vec_factor = 16 / + std::max(dataTypeSize(dtype), + (use_weights ? dataTypeSize(weights_dtype) : 1)); + + const int64_t num_threads = 256; + const int64_t min_bdimx = 8; + const int64_t max_bdimy = num_threads / min_bdimx; + const int64_t max_gdimy = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount / 2; + const int64_t pb_factor = ceilDiv(ceilDiv(N * HW * HW, max_bdimy), max_gdimy); + const auto req_reg_count = pb_factor * vec_factor * dataTypeSize(dtype) / + sizeof(int) * + (is_bwd ? 2 : 1); // Two tensors are cached in the backward batchnorm + + // The scheduler sets aside (pb_factor + 35) registers + return req_reg_count <= 255 - (pb_factor + 35); +} + +} // namespace + +// TODO: Enable once non-welford grid reductions are supported +#if 0 +namespace { + +// Forward grid reduction +void grid_persistent_reduction_outer_norm_like_scheduler( + int64_t N, + int64_t HW, + int64_t C, + DataType dtype, + bool use_weights = false, + DataType weights_dtype = DataType::Float) { + const bool benchmark_mode = isBenchmarkMode(); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector bcast_pattern{true, true, true, false}; + std::vector reduction_dims{2, 1, 0}; + + auto inp = makeContigTensor(4, dtype); + fusion.addInput(inp); + + TensorView* weights = nullptr; + if (use_weights) { + weights = makeContigTensor(1, weights_dtype); + fusion.addInput(weights); + } + + auto inp_cast = cast(inp, DataType::Float); + auto inp_allreduce = broadcast(sum(inp_cast, reduction_dims), bcast_pattern); + auto out = sub(inp_cast, inp_allreduce); + + if (use_weights) { + out = add(out, broadcast(cast(weights, DataType::Float), bcast_pattern)); + } + + out = cast(out, dtype); + fusion.addOutput(out); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto options_weight = at::TensorOptions() + .dtype(data_type_to_aten(weights_dtype)) + .device(at::kCUDA, 0); + at::manual_seed(0); + + const std::vector input_shape{N, HW, HW, C}; + auto t0 = at::randn(input_shape, options); + auto t1 = at::randn({C}, options_weight); + std::vector aten_inputs({t0}); + if (use_weights) { + aten_inputs.push_back(t1); + } + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + + if (!shouldBePersistent(N, HW, dtype, false, use_weights, weights_dtype)) { + 
TORCH_CHECK(runtime->isSegmented(), "Expected to be segmented"); + } else { + TORCH_CHECK( + !runtime->isSegmented(), + "Unexpected number of segments: ", + runtime->fusionSegments()->groups().size()); + + const auto& scheduler_entry = + runtime->schedulerHeuristics()->heuristicsList().at(0); + TORCH_CHECK( + scheduler_entry->heuristic() == ScheduleHeuristic::Persistent, + "Unexpected heuristic was chosen: ", + scheduler_entry->heuristic()); + + if (benchmark_mode) { + for (int i = 0; i < 10; ++i) { + clearL2Cache(); + cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + } + } + } + + auto t0_cast = t0.to(at::kFloat); + auto t0_allreduce = + t0_cast.sum({0, 1, 2}).unsqueeze(0).unsqueeze(0).unsqueeze(0); + auto ref = t0_cast - t0_allreduce; + if (use_weights) { + ref = ref + t1.to(at::kFloat); + } + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__, ""); +} + +} // namespace + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormLikeHalf256x7x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 7, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormLikeHalf256x14x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 14, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormLikeHalf256x28x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 28, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormLikeFloat256x7x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 7, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormLikeFloat256x14x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 14, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormLikeFloat256x28x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 28, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormWithWeightsLikeHalf256x7x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 7, 512, DataType::Half, true, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormWithWeightsLikeHalf256x14x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 14, 512, DataType::Half, true, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormWithWeightsLikeHalf256x28x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_like_scheduler( + 256, 28, 512, DataType::Half, true, DataType::Float); +} +#endif + +namespace { + +// Forward welford +void grid_persistent_welford_outer_norm_like_scheduler( + int64_t N, + int64_t HW, + int64_t C, + DataType dtype, + bool use_weights = false, + DataType weights_dtype = DataType::Float) { + const bool benchmark_mode = isBenchmarkMode(); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector bcast_pattern{true, true, true, false}; + std::vector reduction_dims{2, 1, 0}; + + auto inp = makeContigTensor(4, dtype); + fusion.addInput(inp); + + TensorView* weights = nullptr; + if (use_weights) { + weights = makeContigTensor(1, weights_dtype); + fusion.addInput(weights); + } + + auto inp_cast = cast(inp, DataType::Float); + auto inp_allreduce = + 
broadcast(Welford(inp_cast, reduction_dims).avg, bcast_pattern); + auto out = sub(inp_cast, inp_allreduce); + + if (use_weights) { + out = add(out, broadcast(cast(weights, DataType::Float), bcast_pattern)); + } + + out = cast(out, dtype); + fusion.addOutput(out); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto options_weight = at::TensorOptions() + .dtype(data_type_to_aten(weights_dtype)) + .device(at::kCUDA, 0); + at::manual_seed(0); + + const std::vector input_shape{N, HW, HW, C}; + auto t0 = at::randn(input_shape, options); + auto t1 = at::randn({C}, options_weight); + std::vector aten_inputs({t0}); + if (use_weights) { + aten_inputs.push_back(t1); + } + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + + if (!shouldBePersistent(N, HW, dtype, false, use_weights, weights_dtype)) { + TORCH_CHECK(runtime->isSegmented(), "Expected to be segmented"); + } else { + TORCH_CHECK( + !runtime->isSegmented(), + "Unexpected number of segments: ", + runtime->fusionSegments()->groups().size()); + + const auto& scheduler_entry = + runtime->schedulerHeuristics()->heuristicsList().at(0); + TORCH_CHECK( + scheduler_entry->heuristic() == ScheduleHeuristic::Persistent, + "Unexpected heuristic was chosen: ", + scheduler_entry->heuristic()); + + if (benchmark_mode) { + for (int i = 0; i < 10; ++i) { + clearL2Cache(); + cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + } + } + } + + auto t0_cast = t0.to(at::kFloat); + auto t0_allreduce = + t0_cast.mean({0, 1, 2}).unsqueeze(0).unsqueeze(0).unsqueeze(0); + auto ref = t0_cast - t0_allreduce; + if (use_weights) { + ref = ref + t1.to(at::kFloat); + } + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__, ""); +} + +} // namespace + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormLikeHalf256x7x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 7, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormLikeHalf256x14x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 14, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormLikeHalf256x28x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 28, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormLikeFloat256x7x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 7, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormLikeFloat256x14x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 14, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormLikeFloat256x28x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 28, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormWithWeithtsLikeHalf256x7x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 7, 512, DataType::Half, true, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentWelfordOuterNormWithWeightsLikeWHalf256x14x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 14, 512, DataType::Half, true, DataType::Float); +} + +TEST_F( + NVFuserTest, + 
FusionGridPersistentWelfordOuterNormWithWeightsLikeWHalf256x28x512Scheduler_CUDA) { + grid_persistent_welford_outer_norm_like_scheduler( + 256, 28, 512, DataType::Half, true, DataType::Float); +} + +namespace { + +// Forward batchnorm +void grid_persistent_batchnorm_scheduler( + int64_t N, + int64_t HW, + int64_t C, + DataType dtype) { + const bool benchmark_mode = isBenchmarkMode(); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + + const bool kTraining = true; + const float kMomentum = 0.1; + const float kEps = 1e-5; + + // setup fusion + auto input = makeContigTensor(4, dtype); + auto weight = makeContigTensor(1, dtype); + auto bias = makeContigTensor(1, dtype); + auto running_mean = makeContigTensor(1, DataType::Float); + auto running_var = makeContigTensor(1, DataType::Float); + + fusion_ptr->addInput(input); + fusion_ptr->addInput(weight); + fusion_ptr->addInput(bias); + fusion_ptr->addInput(running_mean); + fusion_ptr->addInput(running_var); + + if (dtype == DataType::Half) { + input = castOp(DataType::Float, input); + weight = castOp(DataType::Float, weight); + bias = castOp(DataType::Float, bias); + } + + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); + + auto result = batch_norm( + input, + weight, + bias, + running_mean, + running_var, + kTraining, + momentum_ptr, + eps_ptr, + true); + + auto output = result.output; + + if (dtype == DataType::Half) { + output = castOp(DataType::Half, output); + } + + fusion_ptr->addOutput(output); + + auto options_float = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::manual_seed(0); + + auto at_input = at::randn({N, C, HW, HW}, options) + .contiguous(c10::MemoryFormat::ChannelsLast); + auto at_input_nvfuser = at_input.clone().detach().permute({0, 2, 3, 1}); + + auto at_weight = at::randn({C}, options); + auto at_bias = at::randn({C}, options); + auto at_running_mean = at::randn({C}, options_float); + auto at_running_var = at::randn({C}, options_float); + + std::vector aten_inputs( + {at_input_nvfuser, at_weight, at_bias, at_running_mean, at_running_var}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + + if (!shouldBePersistent(N, HW, dtype, false, true, DataType::Float)) { + TORCH_CHECK(runtime->isSegmented(), "Expected to be segmented"); + } else { + TORCH_CHECK( + !runtime->isSegmented(), + "Unexpected number of segments: ", + runtime->fusionSegments()->groups().size()); + + const auto& scheduler_entry = + runtime->schedulerHeuristics()->heuristicsList().at(0); + TORCH_CHECK( + scheduler_entry->heuristic() == ScheduleHeuristic::Persistent, + "Unexpected heuristic was chosen: ", + scheduler_entry->heuristic()); + + if (benchmark_mode) { + for (int i = 0; i < 10; ++i) { + clearL2Cache(); + cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + } + } + } + + auto at_output = at::batch_norm( + at_input, + at_weight, + at_bias, + at_running_mean, + at_running_var, + kTraining, + kMomentum, + kEps, + true); + + cg_outputs.at(0) = cg_outputs.at(0).permute({0, 3, 1, 2}); + + testValidate( + &fusion, cg_outputs, aten_inputs, {at_output}, __LINE__, __FILE__, ""); +} + +} // namespace + +TEST_F( + NVFuserTest, + 
FusionGridPersistentBatchNormChannelsLastHalf256x7x512Scheduler_CUDA) { + grid_persistent_batchnorm_scheduler(256, 7, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastHalf256x14x512Scheduler_CUDA) { + grid_persistent_batchnorm_scheduler(256, 14, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastHalf256x28x512Scheduler_CUDA) { + grid_persistent_batchnorm_scheduler(256, 28, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastFloat256x7x512Scheduler_CUDA) { + grid_persistent_batchnorm_scheduler(256, 7, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastFloat256x14x512Scheduler_CUDA) { + grid_persistent_batchnorm_scheduler(256, 14, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastFloat256x28x512Scheduler_CUDA) { + grid_persistent_batchnorm_scheduler(256, 28, 512, DataType::Float); +} + +// TODO: Enable once non-welford grid reductions are supported +#if 0 +namespace { + +// Backward grid reduction +void grid_persistent_reduction_outer_norm_bwd_like_scheduler( + int64_t N, + int64_t HW, + int64_t C, + DataType dtype) { + const bool benchmark_mode = isBenchmarkMode(); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector bcast_pattern{true, true, true, false}; + std::vector reduction_dims{2, 1, 0}; + + // grad_output + auto tv0 = makeContigTensor(4, dtype); + fusion.addInput(tv0); + // input + auto tv1 = makeContigTensor(4, dtype); + fusion.addInput(tv1); + + auto norm = + IrBuilder::create(1.0 / ((double)N * (double)HW * (double)HW)); + + auto tv2 = dtype == DataType::Half ? castOp(DataType::Float, tv0) : tv0; + auto tv3 = dtype == DataType::Half ? castOp(DataType::Float, tv1) : tv1; + // grad_output_sum-like pattern + auto tv4 = sum(tv2, reduction_dims); + auto tv5 = mul(tv4, norm); + auto tv6 = broadcast(tv5, bcast_pattern); + // dot_p-like pattern + auto tv7 = sub(tv2, tv3); + auto tv8 = sum(tv7, reduction_dims); + auto tv9 = mul(tv8, norm); + auto tv10 = broadcast(tv9, bcast_pattern); + + auto tv11 = mul(tv3, tv10); + auto tv12 = sub(tv2, tv11); + auto tv13 = sub(tv12, tv6); + auto tv14 = dtype == DataType::Half ? 
castOp(DataType::Half, tv13) : tv13; + fusion.addOutput(tv14); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::manual_seed(0); + + const std::vector input_shape{N, HW, HW, C}; + auto t0 = at::randn(input_shape, options); + auto t1 = at::randn(input_shape, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + + if (!shouldBePersistent(N, HW, dtype, true)) { + TORCH_CHECK(runtime->isSegmented(), "Expected to be segmented"); + } else { + TORCH_CHECK( + !runtime->isSegmented(), + "Unexpected number of segments: ", + runtime->fusionSegments()->groups().size()); + + const auto& scheduler_entry = + runtime->schedulerHeuristics()->heuristicsList().at(0); + TORCH_CHECK( + scheduler_entry->heuristic() == ScheduleHeuristic::Persistent, + "Unexpected heuristic was chosen: ", + scheduler_entry->heuristic()); + + if (benchmark_mode) { + for (int i = 0; i < 10; ++i) { + clearL2Cache(); + cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + } + } + } + + auto norm_double = 1.0 / ((double)N * (double)HW * (double)HW); + auto t4 = t0.to(at::kFloat); + auto t5 = t1.to(at::kFloat); + auto t6 = sum(t4, {0, 1, 2}); + auto t7 = t6 * norm_double; + auto t8 = t7.unsqueeze(0).unsqueeze(0).unsqueeze(0); + auto t9 = t4 - t5; + auto t10 = sum(t9, {0, 1, 2}); + auto t11 = t10 * norm_double; + auto t12 = t11.unsqueeze(0).unsqueeze(0).unsqueeze(0); + + // Second use of manually projected persistent buffer + auto t13 = t0.to(at::kFloat); + auto t14 = t1.to(at::kFloat); + auto t15 = t14 * t12; + auto t16 = t13 - t15; + auto t17 = t16 - t8; + + testValidate(&fusion, cg_outputs, aten_inputs, {t17}, __LINE__, __FILE__, ""); +} + +} // namespace + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormBwdLikeHalf256x7x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_bwd_like_scheduler( + 256, 7, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormBwdLikeHalf256x14x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_bwd_like_scheduler( + 256, 14, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormBwdLikeHalf256x28x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_bwd_like_scheduler( + 256, 28, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormBwdLikeFloat256x7x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_bwd_like_scheduler( + 256, 7, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormBwdLikeFloat256x14x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_bwd_like_scheduler( + 256, 14, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentReductionOuterNormBwdLikeFloat256x28x512Scheduler_CUDA) { + grid_persistent_reduction_outer_norm_bwd_like_scheduler( + 256, 28, 512, DataType::Float); +} + +namespace { + +// Backward batchnorm +void grid_persistent_batchnorm_bwd_scheduler( + int64_t N, + int64_t HW, + int64_t C, + DataType dtype) { + const bool benchmark_mode = isBenchmarkMode(); + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + + const bool kTraining = true; + const float kEps = 1e-5; + + // setup fusion + auto input = makeContigTensor(4, dtype); + auto 
grad_output = makeContigTensor(4, dtype); + auto weight = makeContigTensor(1, DataType::Float); + auto running_mean = makeContigTensor(1, DataType::Float); + auto running_var = makeContigTensor(1, DataType::Float); + auto save_mean = makeContigTensor(1, DataType::Float); + auto save_var = makeContigTensor(1, DataType::Float); + + fusion.addInput(input); + fusion.addInput(grad_output); + fusion.addInput(weight); + fusion.addInput(running_mean); + fusion.addInput(running_var); + fusion.addInput(save_mean); + fusion.addInput(save_var); + + if (dtype == DataType::Half) { + input = castOp(DataType::Float, input); + grad_output = castOp(DataType::Float, grad_output); + } + + auto eps_ptr = IrBuilder::create(kEps); + + auto result = batch_norm_backward( + input, + grad_output, + weight, + running_mean, + running_var, + save_mean, + save_var, + kTraining, + eps_ptr, + std::vector(3, true), + true); + + auto grad_input = result.grad_input; + auto grad_weight = result.grad_weight; + auto grad_bias = result.grad_bias; + + if (dtype == DataType::Half) { + grad_input = castOp(DataType::Half, grad_input); + grad_weight = castOp(DataType::Half, grad_weight); + grad_bias = castOp(DataType::Half, grad_bias); + } + + fusion.addOutput(grad_input); + fusion.addOutput(grad_weight); + fusion.addOutput(grad_bias); + + auto options_float = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::manual_seed(0); + + const std::vector input_shape{N, HW, HW, C}; + + auto at_input = at::randn({N, C, HW, HW}, options) + .contiguous(c10::MemoryFormat::ChannelsLast); + auto at_input_nvfuser = at_input.clone().detach().permute({0, 2, 3, 1}); + + auto at_grad_out = at::randn({N, C, HW, HW}, options) + .contiguous(c10::MemoryFormat::ChannelsLast); + auto at_grad_out_nvfuser = at_grad_out.clone().detach().permute({0, 2, 3, 1}); + + at::Tensor at_weight = at::ones({C}, options_float); + at::Tensor at_run_mean = at::zeros({C}, options_float); + at::Tensor at_run_var = at::ones({C}, options_float); + at::Tensor at_save_mean = at::zeros({C}, options_float); + at::Tensor at_save_var = at::ones({C}, options_float); + + std::vector aten_inputs( + {at_input_nvfuser, + at_grad_out_nvfuser, + at_weight, + at_run_mean, + at_run_var, + at_save_mean, + at_save_var}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + + if (!shouldBePersistent(N, HW, dtype, true, true, DataType::Float)) { + TORCH_CHECK(runtime->isSegmented(), "Expected to be segmented"); + } else { + TORCH_CHECK( + !runtime->isSegmented(), + "Unexpected number of segments: ", + runtime->fusionSegments()->groups().size()); + + const auto& scheduler_entry = + runtime->schedulerHeuristics()->heuristicsList().at(0); + TORCH_CHECK( + scheduler_entry->heuristic() == ScheduleHeuristic::Persistent, + "Unexpected heuristic was chosen: ", + scheduler_entry->heuristic()); + + if (benchmark_mode) { + for (int i = 0; i < 10; ++i) { + clearL2Cache(); + cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + } + } + } + + // Permute grad_input output + cg_outputs.at(0) = cg_outputs.at(0).permute({0, 3, 1, 2}); + + auto at_output = at::native_batch_norm_backward( + at_grad_out, + at_input, + at_weight, + at_run_mean, + at_run_var, + at_save_mean, + at_save_var, + true, + kEps, + {true, true, true}); + + testValidate( + &fusion, + 
cg_outputs, + aten_inputs, + {std::get<0>(at_output), std::get<1>(at_output), std::get<2>(at_output)}, + __LINE__, + __FILE__, + ""); +} + +} // namespace + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastBwdHalf256x7x512Scheduler_CUDA) { + grid_persistent_batchnorm_bwd_scheduler(256, 7, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastBwdHalf256x14x512Scheduler_CUDA) { + grid_persistent_batchnorm_bwd_scheduler(256, 14, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastBwdHalf256x28x512Scheduler_CUDA) { + grid_persistent_batchnorm_bwd_scheduler(256, 28, 512, DataType::Half); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastBwdFloat256x7x512Scheduler_CUDA) { + grid_persistent_batchnorm_bwd_scheduler(256, 7, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastBwdFloat256x14x512Scheduler_CUDA) { + grid_persistent_batchnorm_bwd_scheduler(256, 14, 512, DataType::Float); +} + +TEST_F( + NVFuserTest, + FusionGridPersistentBatchNormChannelsLastBwdFloat256x28x512Scheduler_CUDA) { + grid_persistent_batchnorm_bwd_scheduler(256, 28, 512, DataType::Float); +} +#endif + } // namespace jit } // namespace torch
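
The grid-outer gating added to registry.cpp above reduces to a handful of integer formulas (min_multi_reduction_factor, required_sm_per_norm, max_multi_reduction_factor). The following standalone C++ sketch, which is not part of the patch, plugs hypothetical values into those formulas. The device constants, the persistent buffer size, and the value 8 standing in for PreferredLaunchConfig::kMinBdimx (matching the min_bdimx used by the shouldBePersistent test helper) are assumptions for illustration only.

#include <algorithm>
#include <cstdint>
#include <iostream>

namespace {
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}
// Assumed to behave like scheduler_utils::safeDiv: integer division clamped
// to at least one.
int64_t safeDiv(int64_t a, int64_t b) {
  return std::max<int64_t>(a / b, 1);
}
} // namespace

int main() {
  // Assumed device constants (roughly A100-like); the scheduler reads the
  // real values from at::cuda::getCurrentDeviceProperties().
  const int64_t regs_per_block = 65536;
  const int64_t sm_count = 108;
  const int64_t sm_register_file_size =
      regs_per_block * static_cast<int64_t>(sizeof(int));

  // Hypothetical runtime inputs: the buffer size would come from
  // scheduler_utils::persistentBufferSize and the vectorization factor from
  // vectorize_helper::getVectorizationFactor.
  const int64_t persistent_buffer_size = 7168;
  const int64_t vectorization_factor = 8;
  const int64_t min_bdimx = 8; // stands in for PreferredLaunchConfig::kMinBdimx

  const int64_t min_multi_reduction_factor = vectorization_factor * min_bdimx;

  const int64_t required_sm_per_norm = ceilDiv(
      persistent_buffer_size * min_multi_reduction_factor,
      sm_register_file_size);

  // Reject if more than half the device is needed per normalization.
  const bool over_half_device = required_sm_per_norm > safeDiv(sm_count, 2);

  const bool is_cross_grid = required_sm_per_norm > 1;

  const int64_t available_persistent_buffer_size =
      sm_register_file_size * sm_count;

  const int64_t max_multi_reduction_factor = safeDiv(
      is_cross_grid ? available_persistent_buffer_size
                    : sm_register_file_size,
      persistent_buffer_size);

  std::cout << "required_sm_per_norm=" << required_sm_per_norm
            << " over_half_device=" << std::boolalpha << over_half_device
            << " is_cross_grid=" << is_cross_grid
            << " max_multi_reduction_factor=" << max_multi_reduction_factor
            << " (min=" << min_multi_reduction_factor << ")" << std::endl;
  return 0;
}

With these hypothetical numbers the fusion would take the cross-grid path (two SMs per normalization) and comfortably satisfy the minimum multi-reduction factor of 64.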
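
The shouldBePersistent helper in the new tests estimates whether the register budget allows a persistent schedule. Below is a standalone sketch of that estimate, evaluated for one concrete shape; the SM count is an assumption here, whereas the helper queries at::cuda::getCurrentDeviceProperties().

#include <algorithm>
#include <cstdint>
#include <iostream>

namespace {
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}
} // namespace

int main() {
  // Hypothetical case: N=256, HW=14, fp16 input with fp32 weights, forward.
  const int64_t N = 256;
  const int64_t HW = 14;
  const int64_t dtype_size = 2;         // Half
  const int64_t weights_dtype_size = 4; // Float
  const bool is_bwd = false;
  const int64_t sm_count = 108;         // assumed multiProcessorCount

  const int64_t vec_factor = 16 / std::max(dtype_size, weights_dtype_size);
  const int64_t num_threads = 256;
  const int64_t min_bdimx = 8;
  const int64_t max_bdimy = num_threads / min_bdimx;
  const int64_t max_gdimy = sm_count / 2;

  const int64_t pb_factor =
      ceilDiv(ceilDiv(N * HW * HW, max_bdimy), max_gdimy);
  const int64_t req_reg_count = pb_factor * vec_factor * dtype_size /
      static_cast<int64_t>(sizeof(int)) * (is_bwd ? 2 : 1);

  // The scheduler is assumed to set aside (pb_factor + 35) registers.
  const bool persistent = req_reg_count <= 255 - (pb_factor + 35);

  std::cout << "pb_factor=" << pb_factor << " req_reg_count=" << req_reg_count
            << " persistent=" << std::boolalpha << persistent << std::endl;
  return 0;
}

For N=256 and HW=14 this yields pb_factor 30 and an estimated 60 registers, well under the 190 left after the assumed (pb_factor + 35) set-aside, so the tests expect an unsegmented persistent kernel for that shape.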
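
The channels-last batch-norm tests hand nvfuser an explicit NHWC view of a ChannelsLast tensor and permute the fusion output back to NCHW before validation. A minimal ATen sketch of that round trip follows; the sizes are illustrative and a CUDA device is assumed to be available.

#include <ATen/ATen.h>

int main() {
  // NCHW tensor stored in channels-last (NHWC) memory order.
  auto x = at::randn({256, 512, 14, 14}, at::kCUDA)
               .contiguous(c10::MemoryFormat::ChannelsLast);

  // Logical NHWC view over the same memory; this is the layout the tests pass
  // to the fusion (after clone().detach() in the tests themselves).
  auto x_nhwc = x.permute({0, 2, 3, 1});
  TORCH_CHECK(x_nhwc.is_contiguous(), "expected a dense NHWC view");

  // A fusion output produced in NHWC order is permuted back to NCHW so it can
  // be compared against the ATen reference.
  auto back_to_nchw = x_nhwc.permute({0, 3, 1, 2});
  TORCH_CHECK(back_to_nchw.sizes() == x.sizes());
  return 0;
}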