Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

combined inner outer reduction, add a simple test case #2400

Open
wants to merge 14 commits into
base: devel
Choose a base branch
from
133 changes: 53 additions & 80 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,100 +565,73 @@ bool hasSharedInput(
return has_shared_input;
}

std::unordered_set<TensorView*> getAllProducerTvsOf(
const std::unordered_set<TensorView*>& tv_set,
bool skip_reduction_tv) {
std::unordered_set<TensorView*> all_producer_tvs;
std::queue<TensorView*> unvisited;

// start search from tv_set, save its direct producers
for (auto tv0 : tv_set) {
for (auto tv : ir_utils::producerTvsOf(tv0)) {
bool skip_tv = skip_reduction_tv && tv->hasReduction();
if (!skip_tv && all_producer_tvs.find(tv) == all_producer_tvs.end()) {
all_producer_tvs.emplace(tv);
unvisited.push(tv);
}
std::unordered_set<TensorView*> getAllTvsFrom(
const std::vector<TensorView*>& from_tvs,
const std::unordered_set<TensorView*>& cutoff_tv_set,
const bool cut_off_reduction_tv) {
std::unordered_set<TensorView*> tv_group;
std::queue<TensorView*> tensors_to_visit;
auto addIfNotVisited = [&](TensorView* tv) {
if (tv_group.find(tv) == tv_group.end() &&
cutoff_tv_set.find(tv) == cutoff_tv_set.end()) {
tv_group.emplace(tv);
tensors_to_visit.push(tv);
}
};

for (auto tv : from_tvs) {
tensors_to_visit.push(tv);
}
// search for indirect producers from tv0's direct producers
while (!unvisited.empty()) {
auto next_tv = unvisited.front();
unvisited.pop();
for (auto tv : ir_utils::producerTvsOf(next_tv)) {
bool skip_tv = skip_reduction_tv && tv->hasReduction();
if (!skip_tv && all_producer_tvs.find(tv) == all_producer_tvs.end()) {
all_producer_tvs.emplace(tv);
unvisited.push(tv);
while (!tensors_to_visit.empty()) {
auto next_tv = tensors_to_visit.front();
tensors_to_visit.pop();
// visit consumers
for (auto tv : ir_utils::consumerTvsOf(next_tv)) {
addIfNotVisited(tv);
}
// visit siblings
for (auto tv : ir_utils::siblingTvsOf(next_tv)) {
addIfNotVisited(tv);
}
// don't visit producer of reduction tv if cut_off_reduction_tv is true
if (!next_tv->hasReduction() || !cut_off_reduction_tv) {
naoyam marked this conversation as resolved.
Show resolved Hide resolved
for (auto tv : ir_utils::producerTvsOf(next_tv)) {
addIfNotVisited(tv);
}
}
}
return all_producer_tvs;
return tv_group;
}

bool hasSharedConsumerNonOuterReductionProducer(
bool isConnectedOnlyThroughReductionProducer(
const std::vector<TensorView*>& inner_reduction_tvs,
const std::vector<TensorView*>& outer_reduction_tvs) {
// step-1, check if inner reduction and outer reduction tvs have a shared
// consumer. Get all consumers of the inner_reduction_tvs; shared consumers
// between inner reduction tvs are allowed.
const auto& all_vals = DependencyCheck::getAllDependentVals(
{inner_reduction_tvs.begin(), inner_reduction_tvs.end()});
const auto& all_tvs = ir_utils::filterByType<TensorView>(all_vals);
std::unordered_set<TensorView*> all_consumer_tvs_inner{
all_tvs.begin(), all_tvs.end()};

// check if outer reduction tvs have any shared consumer with inner reduction
// tvs and other outer reduction tvs
// Tested in FusionCombinedSchedulerSharedConsumer_CUDA
std::unordered_set<TensorView*> all_consumer_tvs_outer;
std::unordered_set<TensorView*> producer_of_consumer_tvs_outer;
for (const auto otv : outer_reduction_tvs) {
const auto& out_vals = DependencyCheck::getAllDependentVals({otv});
const auto& consumers = ir_utils::filterByType<TensorView>(out_vals);
// detect links through consumers
for (const auto tv : consumers) {
if (all_consumer_tvs_inner.find(tv) != all_consumer_tvs_inner.end()) {
return true;
}
if (all_consumer_tvs_outer.find(tv) != all_consumer_tvs_outer.end()) {
return true;
const std::unordered_set<TensorView*> outer_tv_set{
outer_reduction_tvs.begin(), outer_reduction_tvs.end()};
const auto& disjoint_inner_reduction_tvs =
getAllTvsFrom(inner_reduction_tvs, outer_tv_set, false);

std::unordered_set<TensorView*> disjoint_outer_reduction_tvs;
for (auto otv : outer_reduction_tvs) {
const auto& connected_tv_set = getAllTvsFrom({otv}, {}, true);
naoyam marked this conversation as resolved.
Show resolved Hide resolved
for (auto tv : connected_tv_set) {
// case-1, outer reduction tv can't be connected with other outer
// reduction tvs except through their reduction producers
if (disjoint_outer_reduction_tvs.find(tv) ==
disjoint_outer_reduction_tvs.end()) {
disjoint_outer_reduction_tvs.emplace(tv);
} else {
all_consumer_tvs_outer.emplace(tv);
return false;
}
naoyam marked this conversation as resolved.
Show resolved Hide resolved
}
// detect links through consumers' producers
const auto& producers =
getAllProducerTvsOf({consumers.begin(), consumers.end()}, true);
for (const auto tv : producers) {
// check shared producer of outer reduction tvs' consumers
if (producer_of_consumer_tvs_outer.find(tv) !=
producer_of_consumer_tvs_outer.end()) {
return true;
} else {
producer_of_consumer_tvs_outer.emplace(tv);
// case-2, outer reduction tv can't be connected with inner reduction tvs
// except through their reduction producers
if (disjoint_inner_reduction_tvs.find(tv) !=
naoyam marked this conversation as resolved.
Show resolved Hide resolved
disjoint_inner_reduction_tvs.end()) {
return false;
}
}
}

// step-2, check if consumers of inner reduction and outer reduction tvs have
// a shared producer. A shared producer is only allowed if it is used in outer
// reduction since only the first part of outer reduction is computed with
// inner reduction. Tested in FusionCombinedSchedulerSharedProducer_CUDA
all_consumer_tvs_inner.insert(
inner_reduction_tvs.begin(), inner_reduction_tvs.end());
const auto& all_producer_tvs_outer =
getAllProducerTvsOf(all_consumer_tvs_outer, true);
const auto& all_producer_tvs_inner =
getAllProducerTvsOf(all_consumer_tvs_inner, false);
for (auto outer : all_producer_tvs_outer) {
if (all_producer_tvs_inner.find(outer) != all_producer_tvs_inner.end()) {
return true;
}
}

// don't have shared consumer or non-outer-reduction producer
return false;
return true;
}

int64_t partialReductionBufferSize(
Expand Down
11 changes: 6 additions & 5 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,12 @@ bool hasSharedInput(
const std::vector<TensorView*>& inner_reduction_tvs,
const std::vector<TensorView*>& outer_reduction_tvs);

//! check if outer reduction tvs have any shared consumer with inner reduction
//! tvs and other outer reduction tvs.
//! check if consumer of inner reduction tvs and outer reduction tvs have
//! shared non-outer-reduction producer
bool hasSharedConsumerNonOuterReductionProducer(
//! The first part of outer reduction is computed with inner reduction and the
//! second part is scheduled separately. So, (1) the outer reduction tvs can
//! only be connected with inner reduction tvs through their producers. (2)
//! Outer reduction tvs are also scheduled separately and they can only be
//! connected through their producers.
bool isConnectedOnlyThroughReductionProducer(
const std::vector<TensorView*>& inner_reduction_tvs,
const std::vector<TensorView*>& outer_reduction_tvs);

Expand Down
4 changes: 2 additions & 2 deletions third_party/nvfuser/csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1888,8 +1888,8 @@ class PersistentKernelScheduler : public SchedulerEntry {
return false;
}

if (normalization_scheduler_utils::
hasSharedConsumerNonOuterReductionProducer(
if (!normalization_scheduler_utils::
isConnectedOnlyThroughReductionProducer(
inner_reduction_tvs, outer_reduction_tvs)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Persistent,
Expand Down