Skip to content

Commit

Permalink
WIP: Fix #2559
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Mar 9, 2023
1 parent 3b85308 commit dc4b796
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions third_party/nvfuser/csrc/scheduler/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) {
cached_outputs,
dummy_outputs);

normalization_scheduler_utils::fixUpInvalidPersistentBuffers(fusion);

if (rparams.compute_persistent_buffer_with_first_consumer) {
TORCH_INTERNAL_ASSERT(
rparams.persistent_kernel,
Expand Down
32 changes: 32 additions & 0 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include <fusion.h>
#include <inlining.h>
#include <ir_cloner.h>
#include <lower_trivial_broadcast.h>
#include <scheduler/debug_utils.h>
#include <scheduler/normalization_utils.h>
#include <scheduler/utils.h>
#include <utils.h>

#include <ATen/cuda/CUDAContext.h>
Expand Down Expand Up @@ -494,5 +499,32 @@ std::optional<GridOuterNormalizationParams> getGridOuterNormalizationParams(
return std::nullopt;
}

void fixUpInvalidPersistentBuffers(Fusion* fusion) {
auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion);
ConcretizedBroadcastDomains concretize_info(fusion);

for (auto persistent_buffer : persistent_buffer_info.persistent_buffers) {
std::cerr << "PB: " << persistent_buffer->toString() << std::endl;

for (auto axis : persistent_buffer->domain()->domain()) {
if (!concretize_info.isConcretized(axis) || !axis->isThread()) {
continue;
}
// Found
std::cerr << "Concretized broadcast in persistent buffer: "
<< axis->toString() << std::endl;
// Recompute
for (Expr* use : persistent_buffer->uses()) {
auto buffer_replicate = RecomputeTv::recompute(persistent_buffer);
ir_utils::replaceValInExpr(use, persistent_buffer, buffer_replicate);
std::cerr << "Replicated: " << buffer_replicate->toString()
<< std::endl;
}
}
}

inlineMost();
}

} // namespace normalization_scheduler_utils
} // namespace nvfuser
5 changes: 5 additions & 0 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <vector>

namespace nvfuser {

class Fusion;

namespace normalization_scheduler_utils {

//! Utility class to iterate candidates of launch configurations in a
Expand Down Expand Up @@ -145,5 +148,7 @@ std::optional<GridOuterNormalizationParams> getGridOuterNormalizationParams(
int64_t vectorize_factor,
int64_t persistent_buffer_size);

void fixUpInvalidPersistentBuffers(Fusion* fusion);

} // namespace normalization_scheduler_utils
} // namespace nvfuser

0 comments on commit dc4b796

Please sign in to comment.