Enable grid outer persistent scheduling (#2435)
* Enable grid outer persistent scheduling
naoyam authored Feb 15, 2023
1 parent 084e340 commit 7b37a83
Showing 15 changed files with 2,093 additions and 66 deletions.
1 change: 1 addition & 0 deletions third_party/nvfuser/CMakeLists.txt
@@ -111,6 +111,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
5 changes: 4 additions & 1 deletion third_party/nvfuser/benchmark/utils.cpp
@@ -29,7 +29,10 @@ std::string toString(const ReductionParams& rparams) {
ss << " // Iteration Domain: "
<< (rparams.multiple_reds_per_blk ? "multiple reductions per block / "
: "")
<< (rparams.split_grid_dim_iter_dom ? "split grid dimension / " : "")
<< ((rparams.split_grid_dim_iter_dom_inner ||
rparams.split_grid_dim_iter_dom_outer)
? "split grid dimension / "
: "")
<< (rparams.vectorize_iter_dom ? "vectorize / " : "")
<< (rparams.unroll_factor_iter_dom > 1 && !rparams.vectorize_iter_dom
? "unroll / "
27 changes: 27 additions & 0 deletions third_party/nvfuser/csrc/scheduler/debug_utils.h
@@ -1,5 +1,9 @@
#pragma once

#include <utils.h>

#include <iostream>

namespace torch {
namespace jit {
namespace fuser {
@@ -26,6 +30,29 @@ void canScheduleRejectReason(HeuristicType heuristic, const Args&... args) {
"Scheduler _", heuristic, "_ ***rejected*** because : ", args...);
}

// Based on
// https://learn.microsoft.com/en-us/cpp/cpp/ellipses-and-variadic-templates?view=msvc-170#example
inline void log() {
if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) {
std::cerr << std::endl;
}
}

template <typename T>
void log(const T& t) {
if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) {
std::cerr << t << std::endl;
}
}

template <typename First, typename... Rest>
void log(const First& first, const Rest&... rest) {
if (isDebugDumpEnabled(DebugDumpOption::SchedulerVerbose)) {
std::cerr << first;
log(rest...);
}
}

} // namespace scheduler_debug_utils

} // namespace cuda
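The recursive variadic log() overloads added above follow the standard C++ parameter-pack recursion pattern: each call prints its first argument and forwards the rest, and the nullary overload terminates the line. Below is a minimal standalone sketch of the same pattern, with the DebugDumpOption gating omitted and the sketch namespace and example values chosen here purely for illustration:

#include <iostream>

namespace sketch {

// Base case: nothing left to print, so terminate the line.
inline void log() {
  std::cerr << std::endl;
}

// Recursive case: print the first argument, then recurse on the rest.
template <typename First, typename... Rest>
void log(const First& first, const Rest&... rest) {
  std::cerr << first;
  log(rest...);
}

} // namespace sketch

int main() {
  // Prints "bdimx: 128, bdimy: 4" followed by a newline.
  sketch::log("bdimx: ", 128, ", bdimy: ", 4);
  return 0;
}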
176 changes: 144 additions & 32 deletions third_party/nvfuser/csrc/scheduler/normalization.cpp
@@ -1,10 +1,12 @@
#include <scheduler/reduction.h>

#include <executor_utils.h>
#include <grouped_reduction.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <scheduler/normalization_utils.h>
#include <scheduler/reduction_utils.h>
#include <scheduler/registry.h>
#include <scheduler/utils.h>
@@ -13,6 +15,8 @@

#include <ATen/cuda/CUDAContext.h>

#include <cmath>

namespace torch {
namespace jit {
namespace fuser {
@@ -516,7 +520,7 @@ std::shared_ptr<ReductionParams> innerPersistentHeuristic(
if (godim > 1) {
rparams->grid_dim_iter_dom = ParallelType::BIDx;
if (godim > scheduler_utils::x_grid_limit) {
rparams->split_grid_dim_iter_dom = true;
rparams->split_grid_dim_iter_dom_outer = true;
gdimx = scheduler_utils::x_grid_limit;
}
}
@@ -566,6 +570,77 @@ std::shared_ptr<ReductionParams> innerPersistentHeuristic(
return rparams;
}

// Heuristics for grid outer normalizations
std::shared_ptr<ReductionParams> gridOuterPersistentHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const int64_t max_persistent_buffer_size,
const size_t vectorize_factor) {
auto outer_params =
normalization_scheduler_utils::getGridOuterNormalizationParams(
total_reduction_numel,
total_iteration_numel,
vectorize_factor,
max_persistent_buffer_size);

TORCH_INTERNAL_ASSERT(outer_params.has_value(), "No valid config found");

const auto pb_size = outer_params->persistent_buffer_factor;
const auto unswitch_factor = outer_params->unswitch_factor;

auto rparams = std::make_shared<ReductionParams>();

rparams->persistent_kernel = true;
rparams->cross_block_inner_reduction = true;
rparams->cross_grid_inner_reduction = true;
rparams->grid_dim_iter_dom = ParallelType::BIDx;
rparams->grid_dim_inner_reduction = ParallelType::BIDy;
rparams->block_dim_inner_reduction = ParallelType::TIDy;
rparams->batches_per_block_inner_reduction = pb_size;
rparams->multiple_reds_per_blk = true;
rparams->vectorize_iter_dom = true;
rparams->unroll_factor_iter_dom = vectorize_factor;
rparams->block_dim_iter_dom = ParallelType::TIDx;
rparams->unroll_factor_inner_reduction = unswitch_factor;
rparams->split_grid_dim_iter_dom_inner =
ceilDiv(
total_iteration_numel / vectorize_factor,
outer_params->launch_params.bdimx()) >
outer_params->launch_params.gdimx();
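// Note on the split decision just above, with assumed numbers (illustrative
// only, not taken from this commit): for total_iteration_numel = 1048576,
// vectorize_factor = 4, and a chosen launch config of bdimx = 256 with
// gdimx = 216, ceilDiv(1048576 / 4, 256) = 1024 > 216, so
// split_grid_dim_iter_dom_inner is set and gdimx is pinned to 216 in lparams
// below; otherwise gdimx is left as UNINITIALIZED_VAL.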
rparams->compute_persistent_buffer_with_first_consumer = true;
rparams->static_bdimx = true;
rparams->static_bdimy = true;

rparams->lparams = LaunchParams(
rparams->split_grid_dim_iter_dom_inner
? outer_params->launch_params.gdimx()
: LaunchParams::UNINITIALIZED_VAL,
LaunchParams::UNINITIALIZED_VAL,
LaunchParams::UNINITIALIZED_VAL,
outer_params->launch_params.bdimx(),
outer_params->launch_params.bdimy(),
LaunchParams::UNINITIALIZED_VAL);

if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
std::cerr << "\n===== Reduction Stats ========\n"
<< "total_reduction_numel: " << total_reduction_numel << "\n"
<< "total_iteration_numel: " << total_iteration_numel << "\n"
<< "vectorize_factor: " << vectorize_factor << "\n"
<< "n_tensor_inputs: " << n_tensor_inputs << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "max_persistent_buffer_size: " << max_persistent_buffer_size
<< "\n"
<< "persistent_buffer_factor: " << pb_size << "\n"
<< "block(" << outer_params->launch_params.bdimx() << ", "
<< outer_params->launch_params.bdimy() << ", 1)" << std::endl;
std::cerr << rparams->toString() << std::endl;
}

return rparams;
}

// Copied from reduction scheduler, should generalize. Simply needed to take out
// grid reductions.
// TODO: Check adding iteration domain unrolling
@@ -587,13 +662,6 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
const int64_t device_multiprocessor_count =
(int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

auto const max_unroll = ceilDiv(
// Available unrolling based on size of data type
(int64_t)16 / (int64_t)max_input_dtype_size,
// Reduce unrolling if we have many inputs, start reduction at 4 inputs
scheduler_utils::lastPow2(
std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1)));

// If it fits in l2, we just want to make sure each warp uses 32Bytes. Set
// minimum warp as 16 threads instead of 32 as if we have a small reduction
// dim going a bit smaller than 32 usually helps.
@@ -602,16 +670,44 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
? (int64_t)32 / max_input_dtype_size
: 16;

// Initialization
const auto register_file_size =
at::cuda::getCurrentDeviceProperties()->regsPerBlock * sizeof(int);

// Each block runs N reductions, where N is defined as:
// vectorize_factor * blockDim.x. The minimum number of SMs to run
// this as a persistent kernel is thus defined as:
const int64_t min_required_sm_per_norm = ceilDiv(
max_persistent_buffer_size * vectorize_factor *
normalization_scheduler_utils::PreferredLaunchConfig::kMinBdimx,
register_file_size);

if (min_required_sm_per_norm > 1) {
return gridOuterPersistentHeuristic(
total_reduction_numel,
total_iteration_numel,
n_tensor_inputs,
max_input_dtype_size,
max_persistent_buffer_size,
vectorize_factor);
}
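// Example with assumed numbers (illustrative only): regsPerBlock = 65536
// registers gives register_file_size = 65536 * 4 = 262144 bytes; with
// max_persistent_buffer_size = 65536 bytes, vectorize_factor = 4, and
// kMinBdimx taken to be 8 for the sake of the example,
// min_required_sm_per_norm = ceilDiv(65536 * 4 * 8, 262144) = 8 > 1, so the
// grid outer heuristic above is used. A buffer of 8192 bytes or less would
// instead keep the single-SM, block-persistent path below.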

int64_t target_blocks = 1;
int64_t target_unroll = 1;
int64_t max_threads_in_block = warp_size;

// If we have one warp per block, check if that's enough to saturate the SMs.
// Blocks can't come out of reduction dimension, so only use iteration
// dimension here.
// If we have one warp per block, check if that's enough to saturate the
// SMs. Blocks can't come out of reduction dimension, so only use
// iteration dimension here.
target_blocks = ceilDiv(total_iteration_numel, (int64_t)warp_size);

const auto max_unroll = ceilDiv(
// Available unrolling based on size of data type
(int64_t)16 / (int64_t)max_input_dtype_size,
// Reduce unrolling if we have many inputs, start reduction at 4
// inputs
scheduler_utils::lastPow2(
scheduler_utils::safeDiv((int64_t)n_tensor_inputs, 4)));
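// For example (assumed sizes, not from this commit): with
// max_input_dtype_size = 4 and n_tensor_inputs = 6, the available unroll is
// 16 / 4 = 4 and the divisor is lastPow2(safeDiv(6, 4)) = 1, giving
// max_unroll = 4; with 8 inputs of the same type the divisor becomes 2 and
// max_unroll drops to 2.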

// If we have more than a wave of blocks, put parallelism into unrolling
if (target_blocks > device_multiprocessor_count) {
target_unroll = std::min(
@@ -636,10 +732,8 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(

// Compute maximum number of reductions we could do in the same kernel based
// on persistent buffer size

const int64_t max_multi_reduction_factor = std::max(
scheduler_utils::register_file_size / max_persistent_buffer_size,
(int64_t)1);
const int64_t max_multi_reduction_factor = scheduler_utils::safeDiv(
scheduler_utils::register_file_size, max_persistent_buffer_size);

// To get to target threads:
// Prioritize
@@ -652,16 +746,11 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
// (2) y dim in multiple reductions - need to flip unrolling to reduction
// domain for this

// Blocks for outputs
// int64_t gdimx = 1; // unused at this time, comment for clang tidy

// Threads for reduction
int64_t bdimy = 1;
// Threads for output
int64_t bdimx = 1;

int64_t gdimx = 1;

// Unroll amount
int64_t inner_reduction_unroll_factor = 1;
int64_t iter_unroll_factor = 1;
@@ -673,12 +762,12 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
bdimx = scheduler_utils::lastPow2(bdimx);
}

// Prioritie unrolling on iteration domain, but don't sacrifice occupancy,
// Prioritize unrolling on iteration domain, but don't sacrifice occupancy,
// make sure there is at least one wave.
if (ceilDiv(total_iteration_numel, bdimx) > 2 * device_multiprocessor_count) {
iter_unroll_factor = std::min(
std::min(
std::max(max_multi_reduction_factor / bdimx, (int64_t)1),
scheduler_utils::safeDiv(max_multi_reduction_factor, bdimx),
max_unroll),
ceilDiv(device_multiprocessor_count, bdimx));
}
@@ -690,10 +779,10 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
// Put more into bdimx
bdimx = std::min(
std::min(
std::max(
scheduler_utils::safeDiv(
// Don't exceed multi reduction factor
max_multi_reduction_factor / iter_unroll_factor,
(int64_t)1),
max_multi_reduction_factor,
iter_unroll_factor),
// Leave a full wave of blocks
ceilDiv(
total_iteration_numel,
@@ -711,7 +800,7 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(

// Fill bdimy with left over threads
bdimy = std::min(
std::max(max_threads_in_block / bdimx, (int64_t)1),
scheduler_utils::safeDiv(max_threads_in_block, bdimx),
total_reduction_numel);

bool vectorize = false;
@@ -724,6 +813,15 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
(int64_t)vectorize_factor);
}

int64_t sm_required_per_norm_set = ceilDiv(
max_persistent_buffer_size * bdimx * iter_unroll_factor,
scheduler_utils::register_file_size);

TORCH_INTERNAL_ASSERT(
sm_required_per_norm_set == 1,
"Tried to use multiple SMs on an outer persistent kernel ",
"yet this kernel should have been within block persistent.");

// Since this is persistent and registers will have to be used anyways unroll
// the reduction dim if it's available
inner_reduction_unroll_factor =
@@ -751,7 +849,7 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
batches_per_block != roundUpPow2Or8(batches_per_block / 2)) {
batches_per_block = roundUpPow2Or8(batches_per_block / 2);

// Adjust bdimx based on batches_per_block and unroll factor set
// Adjust bdimy based on batches_per_block and unroll factor set
bdimy = ceilDiv(
total_reduction_numel,
inner_reduction_unroll_factor * batches_per_block);
@@ -775,7 +873,7 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
bdimx = ceilDiv(bdimx, 2);
}

gdimx = ceilDiv(total_iteration_numel, bdimx);
int gdimx = ceilDiv(total_iteration_numel, bdimx);

auto rparams = std::make_shared<ReductionParams>();
rparams->batches_per_block_inner_reduction = batches_per_block;
@@ -791,7 +889,8 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
}

rparams->grid_dim_iter_dom = ParallelType::BIDx;
rparams->split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit;
rparams->split_grid_dim_iter_dom_outer =
gdimx > scheduler_utils::x_grid_limit;

if (rparams->block_dim_iter_dom == ParallelType::TIDx) {
rparams->block_dim_inner_reduction = ParallelType::TIDy;
@@ -1052,8 +1151,6 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) {
// fusion segmentation
scheduler_utils::clearMemorySpace(fusion);

auto persistent_info = scheduler_utils::persistentBuffers(fusion);

auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);

TORCH_INTERNAL_ASSERT(reduction_tvs.size());
@@ -1072,6 +1169,11 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) {
scheduler_utils::domainReorderAsRfactorMap(reduction_tv));
}

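// On the new grid outer persistent path (cross-grid inner reduction,
// non-fastest-dim, more than one reduction tensor), the reductions are
// grouped so they can be lowered together as a grouped reduction
// (grouped_reduction.h) rather than issuing a separate cross-grid reduction
// per tensor; the semantics of the boolean argument come from that header.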
if (rparams.persistent_kernel && rparams.cross_grid_inner_reduction &&
!rparams.fastest_dim && reduction_tvs.size() > 1) {
groupReductions(reduction_tvs, false);
}

auto dim_analysis = scheduler_utils::canonicalDimReduction(
fusion, reduction_tv, rparams.fastest_dim && rparams.schedule_3D);
bool has_iter_axis = dim_analysis.first;
@@ -1099,6 +1201,7 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) {
for (auto output : dummy_outputs) {
fusion->addOutput(output);
}

reduction_scheduler_utils::multiReductionInliner(
fusion,
rparams,
@@ -1108,6 +1211,15 @@ std::shared_ptr<ReductionParams> outerPersistentHeuristic(
cached_inputs,
cached_outputs,
dummy_outputs);

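// Per the flag name, each persistent buffer (a cached input) is computed
// with its first consumer. The computeWith(-1, true) call below is assumed
// to request the innermost (rightmost) compute-with position on a
// best-effort basis, following nvfuser's computeAt position convention.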
if (rparams.compute_persistent_buffer_with_first_consumer) {
TORCH_INTERNAL_ASSERT(
rparams.persistent_kernel,
"computeWith should be only used with persistent kernels");
for (const auto persistent_buffer : cached_inputs) {
persistent_buffer->computeWith(-1, true);
}
}
}

} // namespace cuda