combined inner outer reduction, add a simple test case #2400

Open · wants to merge 14 commits into devel
28 changes: 19 additions & 9 deletions third_party/nvfuser/csrc/scheduler/normalization.cpp

@@ -1828,16 +1828,26 @@ void schedulePersistentKernelInnerOuter(
   // directly from output tv using parallelizeAllLike. Must propagate separately
   // for different tvs, as outer reductions are transformed separately.
   if (rparams.vectorization_factor_outer > 1) {
+    auto findVectorizedOutputOf = [&](TensorView* tv) {
+      TensorView* ref_tv = nullptr;
+      for (auto output_tv : ir_utils::outputTvsOf(tv)) {
+        for (auto id : output_tv->domain()->domain()) {
+          if (id->getParallelType() == ParallelType::Vectorize) {
+            ref_tv = output_tv;
+            break;
+          }
+        }
+        if (ref_tv) {
+          break;
+        }
+      }
+      return ref_tv;
+    };
     for (auto tv : cached_gmem_reload) {
-      auto output_tvs = ir_utils::outputTvsOf(tv);
-      TORCH_INTERNAL_ASSERT(
-          !output_tvs.empty(),
-          "cached_gmem_reload should have at least one output tensor.")
-      scheduler_utils::parallelizeAllLike(
-          output_tvs[0],
-          -1,
-          {cached_gmem_reload.begin(), cached_gmem_reload.end()},
-          {ParallelType::Vectorize});
+      if (auto ref_tv = findVectorizedOutputOf(tv)) {
+        scheduler_utils::parallelizeAllLike(
+            ref_tv, -1, {tv}, {ParallelType::Vectorize});
+      }
     }
   }
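
For reference, below is a minimal, self-contained C++ sketch of the search pattern the new findVectorizedOutputOf lambda implements: scan each output's domain for an axis parallelized with Vectorize and return the first matching tensor, or nullptr if none exists. The Tv and Id structs are hypothetical stand-ins for nvfuser's TensorView and IterDomain; only the control flow mirrors the diff.

// Minimal sketch of the findVectorizedOutputOf search pattern.
// Tv and Id are hypothetical stand-ins for TensorView/IterDomain.
#include <iostream>
#include <string>
#include <vector>

enum class ParallelType { Serial, Vectorize };

struct Id {
  ParallelType ptype = ParallelType::Serial;
};

struct Tv {
  std::string name;
  std::vector<Id> domain;
};

// Return the first output whose domain contains a vectorized axis,
// or nullptr if none does; the caller skips propagation in that case.
const Tv* findVectorizedOutputOf(const std::vector<Tv>& outputs) {
  for (const auto& output_tv : outputs) {
    for (const auto& id : output_tv.domain) {
      if (id.ptype == ParallelType::Vectorize) {
        return &output_tv; // early return replaces the flag + double break
      }
    }
  }
  return nullptr;
}

int main() {
  std::vector<Tv> outputs = {
      {"t0", {{ParallelType::Serial}}},
      {"t1", {{ParallelType::Serial}, {ParallelType::Vectorize}}}};
  if (const Tv* ref_tv = findVectorizedOutputOf(outputs)) {
    std::cout << "reference tv: " << ref_tv->name << "\n"; // prints t1
  }
  return 0;
}

Returning from inside the inner loop expresses the same early exit that the lambda implements with the ref_tv flag and two breaks. Note the behavioral change visible in the diff: instead of asserting that outputs exist and propagating from output_tvs[0] across all of cached_gmem_reload, the loop now propagates only when an output with a vectorized axis is actually found, and only to the single tv being processed.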
