Add fusion_debug dump option with Val log #2326
base: devel
Conversation
This enables setting the print format, including printing Fusions to arbitrary ostreams. Currently there are two options for `SerializationFormat`: Default and Debug. The environment variable `PYTHON_NVFUSER_PRINT_FORMAT` is checked whenever a Fusion is printed (not when you print other individual `Statement`s). The information printed is still somewhat sparse, but I plan to add more detail to help track things like the mismatch in variable name indices between manual and automatically scheduled tests.
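Roughly, the format selection works like this (a sketch, not the literal code in this PR; the helper name and the accepted strings are assumptions for illustration):
#include <cstdlib>
#include <string>

enum class SerializationFormat { Default, Debug }; // NameOnly, discussed below, would slot in here too.

// Sketch: map PYTHON_NVFUSER_PRINT_FORMAT to a SerializationFormat.
// The accepted string ("debug") is an assumption for illustration only.
SerializationFormat printFormatFromEnv() {
  const char* env = std::getenv("PYTHON_NVFUSER_PRINT_FORMAT");
  if (env == nullptr) {
    return SerializationFormat::Default;
  }
  const std::string value(env);
  if (value == "debug") {
    return SerializationFormat::Debug;
  }
  return SerializationFormat::Default;
}
Printing a Fusion consults this check; printing an individual Statement does not.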
NameOnly is meant to be a shorter printout, but currently not all Statements have names that I can print. This commit introduces a sorted `fusion_debug` dump option that enables diffing.
I think we may want this to be a separate dump option since it is a pretty concise look at what happened without the noise that comes with fusion_debug. I'm not currently sure what utility fusion_debug has other than this log, but that may be because I haven't had to dive into the internals of the Vals and Exprs yet for a real problem.
The latter (`VAL_LOG_EXPLICIT`) is to be used when not logging to `this`. I made these macros so that we could avoid computing the arguments to log() in Release builds, but that check is currently disabled since I'm not sure we'd want to force a full recompile of pytorch in order to dump this info.
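As a sketch of the intent (the exact macro bodies are not spelled out here, so treat these as assumptions), the guarded version would look roughly like:
// VAL_LOG logs to `this`; VAL_LOG_EXPLICIT logs to an explicitly named Val.
// With the guard enabled, the stringified arguments are not even evaluated
// in Release (NDEBUG) builds; as noted above, the guard is currently disabled.
#ifndef NDEBUG
#define VAL_LOG(op_name, ...) this->log((op_name), {__VA_ARGS__})
#define VAL_LOG_EXPLICIT(val, op_name, ...) (val)->log((op_name), {__VA_ARGS__})
#else
#define VAL_LOG(op_name, ...) ((void)0)
#define VAL_LOG_EXPLICIT(val, op_name, ...) ((void)0)
#endif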
This PR was motivated by cases like the following. Consider this basic fusion (pytorch/third_party/nvfuser/test/test_gpu_match_frontend.cpp, lines 197 to 205 at a06224f):
Manual and automatic scheduling give identical fusion_ir printouts, but the generated kernels differ slightly. These kernels are obviously equivalent, but there are some differences, possibly in the ordering of operations. When we dump fusion_debug and diff, we see the following operation log (green is the automatic and red is the manual schedule). Now we see a lot of differences. There are other differences in the detailed dump that precedes the op log, indicating that some intermediate tensors (which are not shown at all in the fusion math or transforms view) may be parallelized differently.
I currently print a warning in `Fusion::printDebug` when running a Release build.
With the Fusion operation log printout above, I was able to exactly match the automatic scheduler for this problem, and learned a few things about how the automatic scheduler works. Here is what my previous manual schedule looked like:
// Perform manual scheduling
tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv2->split(1, 1);
tv2->reorder({{-1, -2}, {-2, -1}});
auto tv3 = tv2->rFactor({1, 3});
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(2)->parallelize(ParallelType::TIDx);
tv3->axis(3)->parallelize(ParallelType::Unswitch);
// propagate the mapping to other tensors
TransformPropagatorWithCheck propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv3, {tv0, tv1, tv2});
tv1->computeAt(tv3, -1);
inlineMost();
And here is the one that matches the auto scheduler:
// Perform manual scheduling
tv2->reorder({{1, 0}});
tv2->reorder({{1, 0}});
tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv2->axis(2)->parallelize(ParallelType::TIDx);
tv2->split(1, 1);
tv2->axis(2)->parallelize(ParallelType::Unswitch);
tv2->axis(0)->parallelize(ParallelType::BIDx);
// tv2->reorder({{-2, -1}}) has same effect but this shows the mapping explicitly
tv2->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 2}});
auto tv3 = tv2->rFactor({1, 3});
// propagate the mapping to other tensors
TransformPropagatorWithCheck propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv3, {}, allParallelTypesExcept(
{ParallelType::Unroll,
ParallelType::Vectorize,
ParallelType::MisalignedVectorize}));
inlineMost();
Note that these both give equivalent results. One odd thing is the double reorder at the very beginning of the automatic schedule; it doesn't affect the generated kernel.
To make these tests pass, including matching the generated kernels, I used the WIP fusion_debug dump output from #2326. This let me see a noisy log of operations that revealed a few patterns I was unaware of. For example, in the FrontendAdd test the pointwise scheduler is used, and the order of inlining at the end is not defined (and does change) due to the use of an unordered_set. Also, getPointwiseHeuristics creates a Val (fusion.oneVal()), so it technically alters the Fusion by adding a Val to its vals_ list. Beyond those trivial things, I noticed differences in inlining patterns between the pointwise and reduction schedulers, and also saw some different ways to call methods like parallelizeAllLike. One thing I haven't looked into further is the reorders that often happen at the beginning of scheduling both pointwise and reduction fusions. They have no effect on the generated kernels, but they are noticeable and I plan to read up on them soon.
Again I used the fusion_debug dump from #2326 to trace what the reduction scheduler is doing. This time I learned about multiReductionInliner, which uses two calls to parallelizeAllLike for different types of ParallelTypes, followed by an undoing of unrolling and vectorization on the reference tensor. The need for the latter is still a little unclear to me.
FUSER_PERF_SCOPE("Fusion::printDebug");

out << "Fusion DEBUG INFO {";
std::vector<Val*> inputs_;
This local vector shadows the fusion's `inputs_` member, so it always shows that the fusion doesn't have any inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think we just need to delete this line.
Overall, looks pretty good. Would be really useful as it would save repeated ad-hoc printf insertions!
My comments are mostly just some suggestions to consider if we could reduce repeated code by using some of the existing utility routines.
@@ -118,6 +120,32 @@ TORCH_CUDA_CU_API bool isOptionEnabled(EnableOption option);
TORCH_CUDA_CU_API const std::vector<std::string>& getEnableOptionArguments(
    EnableOption option);

TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
Is this a duplicate of another declaration at line 77? Looks like getDebugDumpArguments below is also a duplicate.
//! Write a message to this object's log
void log(std::string op_name, std::vector<std::string> arg_strings) {
#ifdef NDEBUG
  static bool warned = false;
There's TORCH_WARN and TORCH_WARN_ONCE. Not very familiar with them, but you might want to use either of them.
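For example, the static flag above could probably become something like this (message text is just illustrative):
#ifdef NDEBUG
  // TORCH_WARN_ONCE (c10/util/Exception.h) handles the warn-only-once bookkeeping.
  TORCH_WARN_ONCE(
      "Val logging is limited in Release (NDEBUG) builds; "
      "the fusion_debug operation log may be incomplete.");
#endif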
@@ -361,6 +365,37 @@ class TORCH_CUDA_CU_API Val : public Statement {

void resolveIndexDtype();

//! Get vector of log messages for this object
nit: I'd prefer to have the type explicitly stated as std::vector<...> rather than having a comment saying it's a vector.
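For instance (the accessor name and element type below are hypothetical, just to show the spelled-out return type):
//! Get the log messages recorded for this object
const std::vector<std::string>& getLog() const {
  return log_;
}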
}

//! Write a message to this object's log
void log(std::string op_name, std::vector<std::string> arg_strings) {
Can you move the def to ir_base_nodes.cpp?
}

//! Write a message to this object's log
void log(std::string op_name, std::vector<std::string> arg_strings) {
What operation does op_name refer to?
I see what this means now after reading the other parts of the code. Would be helpful to have some (short) comments about the parameters.
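Something short like this would do (wording is only a suggestion):
//! Write a message to this object's log.
//! \param op_name Name of the operation being recorded, e.g. "IterDomain::split".
//! \param arg_strings Stringified arguments of that operation, one entry per argument.
void log(std::string op_name, std::vector<std::string> arg_strings);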
@@ -62,6 +62,7 @@ class TORCH_CUDA_CU_API Scalar : public Val {
      (c10::is_complex<UnderlyingType>::value && isComplexType(dtype)),
      "Invalid data type: ",
      dtype);
  VAL_LOG("Scalar::Scalar", "IrBuilderPasskey", typePrefix(dtype));
I'd drop IrBuilderPasskey as it's not really important to be printed out.
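i.e. something like:
// Keep only the information worth reading back in the log; the passkey adds no signal.
VAL_LOG("Scalar::Scalar", typePrefix(dtype));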
"IterDomain::split", | ||
factor->toString(0, SerializationFormat::NameOnly), | ||
std::to_string(inner_split), | ||
std::to_string(trim_out_of_bounds), ); |
There's a Printer utility class in utils.h. Could it help reduce repetition? For example, could VAL_LOG_EXPLICIT just take whatever objects and use Printer to convert them to strings?
@@ -2576,6 +2702,13 @@ void TensorDomain::reorder(const std::unordered_map<int, int>& old2new_) {
  TORCH_INTERNAL_ASSERT(
      !(nDims() == 0 && old2new_.size() > 0),
      "Tried to reorder a 0-dim domain");
#ifndef NDEBUG
Why is this guarded? Isn't VAL_LOG a no-op when NDEBUG is defined?
bool comma = false;
for (auto t : selected_tvs) {
  if (comma)
    ss << ", ";
There's also toDelimitedString in utils.h, which you may be able to just use or extend.
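For example, assuming toDelimitedString takes a container plus an optional delimiter (signature not checked here), the manual comma handling could collapse to:
// Assumption: toDelimitedString stringifies each element and joins them with ", ".
ss << toDelimitedString(selected_tvs, ", ");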
This PR adds a new dump option called `fusion_debug` which prints all the `Val`s and `Expr`s in a `Fusion` during `compileFusion`, and also prints an ordered list of operations that were done on those `Val`s. For example:

Notice that the names of `Val`s in this log are all abbreviated: they're only the names, without the other information that's typically printed in `toString()`. To accomplish that, `toString` now takes a `fmt` argument whose type `SerializationFormat` can be one of `Default` (current behavior), `NameOnly`, or `Debug`. The `Debug` option prints slightly more than `Default`: in the case of a `TensorView` it prints `contiguity_`, for example. The `SerializationFormat` enum can be extended in the future to include formats like JSON, but I haven't explored that much yet.
enum can be extended in the future to includes stuff like JSON, but I haven't explored that much yet.