Fix Frontend{Add,SuperBasic,SuperBasicFP16} tests
To make these tests pass, including matching the generated kernels, I
used the WIP fusion_debug dump output from
#2326. That noisy log of operations revealed a few patterns I was
unaware of. For example, the FrontendAdd test uses the pointwise
scheduler, whose final inlining order is not defined (and does change
between runs) because it iterates over an unordered_set. Also,
getPointwiseHeuristics actually creates a Val (fusion.oneVal()), so it
technically alters the Fusion by adding a Val to its vals_ list.
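
As a consequence, the tests now define each fusion twice instead of copying
it: the automatically scheduled copy receives the extra Val from
getPointwiseHeuristics, and the manually scheduled fusion mirrors that side
effect explicitly. A condensed sketch of the pattern, using names from the
FrontendAdd diff below (not the full test body):

Fusion fauto;
{
  FusionGuard fg(&fauto);
  auto tv0 = makeSymbolicTensor(3);
  auto tv1 = makeSymbolicTensor(3);
  fauto.addInput(tv0);
  fauto.addInput(tv1);
  fauto.addOutput(add(tv0, tv1));

  // getPointwiseHeuristics() ends up calling fauto.oneVal(), adding a Val to
  // fauto's vals_ list, so it is run on this separately defined copy.
  auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
  TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
  schedulePointwise(&fauto, *pointwise_params);
}

// ... define the manually scheduled Fusion ("fusion") the same way, then:
fusion.oneVal(); // mirror the Val created by getPointwiseHeuristics()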

Beyond those trivial things, I noticed differences in inlining patterns
between the pointwise and reduction schedulers, as well as different
ways of calling methods like parallelizeAllLike (contrasted in the
sketch below). One thing I haven't looked into further is the reorders
that often happen at the beginning of scheduling both pointwise and
reduction fusions; they have no effect on the generated kernels, but
they are noticeable and I plan to read up on them soon.
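
For reference, the two inlining tails as they appear in the diff below; the
tensor names and inlining positions are those of the FrontendAdd and
FrontendSuperBasic tests and are shown only to contrast the two schedulers:

// Pointwise scheduler (FrontendAdd): inline everything at a fixed position
// relative to the reference tensor, then fully inline the innermost tensors.
scheduler_utils::parallelizeAllLike(tv2);
inlineAllAt(tv2, 2, true);
inlineMost(std::vector<TensorView*>({tv5, tv1, tv0}));

// Reduction scheduler (FrontendSuperBasic): parallelize like the rFactor
// tensor for all parallel types except the unrolled/vectorized ones, then a
// single inlineMost() over the whole fusion.
scheduler_utils::parallelizeAllLike(
    tv3,
    {},
    allParallelTypesExcept(
        {ParallelType::Unroll,
         ParallelType::Vectorize,
         ParallelType::MisalignedVectorize}));
inlineMost();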
jacobhinkle committed Jan 18, 2023
1 parent 35b872d commit bb1d5ab
Showing 1 changed file with 157 additions and 69 deletions.
226 changes: 157 additions & 69 deletions third_party/nvfuser/test/test_gpu_match_frontend.cpp
@@ -108,7 +108,29 @@ TEST_F(NVFuserTest, FusionFrontendAdd_CUDA) {

std::vector<IValue> inputs = {t0, t1};

// Define fusion
// Define fusion for automatic scheduling
Fusion fauto;
{
FusionGuard fg(&fauto);

auto tv0 = makeSymbolicTensor(3);
auto tv1 = makeSymbolicTensor(3);

fauto.addInput(tv0);
fauto.addInput(tv1);

auto tv2 = add(tv0, tv1);
// auto tv4 = sum(tv2, {-1}, false, DataType::Float);

fauto.addOutput(tv2);

// Run automatic scheduler
auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
schedulePointwise(&fauto, *pointwise_params);
}

// Repeat definition of fusion for manual scheduling
Fusion fusion;
FusionGuard fg(&fusion);

@@ -123,37 +145,59 @@ TEST_F(NVFuserTest, FusionFrontendAdd_CUDA) {

fusion.addOutput(tv2);

// Run automatic scheduler
auto fauto = Fusion(fusion); // unique_ptr to copy of fusion
auto pointwise_params = getPointwiseHeuristics(&fauto, inputs);
TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!");
schedulePointwise(&fauto, *pointwise_params);

// Perform manual scheduling
auto tv0p = tv0->cacheAfter();
auto tv1p = tv1->cacheAfter();

// Before schedulePointwise() is called, getPointwiseHeuristics() calls
// vectorize_helper::getExpandedVectorization() which in turn calls:
// vectorize_helper::getVectorizationSize
// vectorize_helper::ProjectedExtent::getNumerator
// vectorize_helper::ProjectedExtent::computeNumerDenomir
// IrContainer::oneVal
// oneVal() creates an actual Val here to hold the denominator and
// initializes it to 1. Since this is reflected in the fusion log, I'm
// inserting it here even though it has no effect on the generated kernel.
fusion.oneVal();

// scheduler_utils::cacheInputs(fusion, true);
tv0->cacheAfter(); // tv3
tv1->cacheAfter(); // tv4

// scheduler_utils::cacheAndForkOutputs(fusion, true);
auto tv5 = tv2->cacheBefore(); // tv5

tv2->merge(1, 2);
tv2->merge(0, 1);
tv2->reorder({{0, -1}});
tv2->reorder({{-1, 0}});
tv2->split(0, 128);
tv2->split(0, 1);
tv2->split(0, 1);
auto tv2l = tv2->cacheBefore();
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::Unswitch);
tv2->axis(3)->parallelize(ParallelType::TIDx);

inlineMost();
tv0p->computeAt(tv2, 2);
tv1p->computeAt(tv2, 2);
// inlineMost();
// tv3->computeAt(tv2, 2);
// tv4->computeAt(tv2, 2);

TransformPropagatorWithCheck propagator(tv2);
MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv2, {tv0, tv1, tv0p, tv1p, tv2l});
scheduler_utils::parallelizeAllLike(tv2);

// Pointwise scheduler does not use inlineMost(), as reduction scheduler does
// inlineMost();
// Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors)
inlineAllAt(tv2, 2, true);
inlineMost(std::vector<TensorView*>({tv5, tv1, tv0}));

compare_ir(fusion, fauto);
// Note that inlineAllAt iterates through an unordered_set to do inlining, so
// it is not practical to match the fusion_debug log exactly when using
// pointwise scheduler
compare_ir_math(fusion, fauto);
compare_transforms(fusion, fauto);
// compare_fusion_debug(fusion, fauto);
compare_kernels(fusion, fauto);

// compare_ir(fusion, fauto);

// Perform eager computation and verify
auto t2 = t0.add(t1);
@@ -190,7 +234,29 @@ TEST_F(NVFuserTest, FusionFrontendSuperBasic_CUDA) {

std::vector<IValue> inputs = {t0};

// Define fusion
Fusion fauto;
{ // Do automatic scheduling on fauto
FusionGuard fg(&fauto);

auto tv0 = makeSymbolicTensor(2); // {i0, i1}
auto c0 = IrBuilder::create<Double>(3.0);

fauto.addInput(tv0);

auto tv1 = mul(tv0, c0); // {i0, i1}
auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1}

fauto.addOutput(tv2);

// Run automatic scheduler
auto reduction_params = getReductionHeuristics(&fauto, inputs);
TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
scheduleReduction(&fauto, *reduction_params);
}

// Re-define the fusion exactly for manual scheduling.
// This is necessary in order to capture all the constructor calls inside
// each Fusion independently.
Fusion fusion;
FusionGuard fg(&fusion);

@@ -204,34 +270,36 @@ TEST_F(NVFuserTest, FusionFrontendSuperBasic_CUDA) {

fusion.addOutput(tv2);

// Run automatic scheduler
auto fauto = Fusion(fusion); // unique_ptr to copy of fusion
auto reduction_params = getReductionHeuristics(&fauto, inputs);
TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
scheduleReduction(&fauto, *reduction_params);

// Perform manual scheduling

tv2->reorder({{1, 0}}); // Removing these two reorders does not affect the
// generated kernel
tv2->reorder({{1, 0}});
tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv2->axis(2)->parallelize(ParallelType::TIDx);
tv2->split(1, 1);
tv2->reorder({{-1, -2}, {-2, -1}});
tv2->axis(2)->parallelize(ParallelType::Unswitch);
tv2->axis(0)->parallelize(ParallelType::BIDx);

auto tv3 = tv2->rFactor({1, 3});
// tv2->reorder({{-2, -1}}) has the same effect, but this shows the mapping
// explicitly
tv2->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 2}});

tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(2)->parallelize(ParallelType::TIDx);
tv3->axis(3)->parallelize(ParallelType::Unswitch);
auto tv3 = tv2->rFactor({1, 3});

// propagate the mapping to other tensors
TransformPropagatorWithCheck propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv3, {tv0, tv1, tv2});

tv1->computeAt(tv3, -1);
scheduler_utils::parallelizeAllLike(
tv3,
{},
allParallelTypesExcept(
{ParallelType::Unroll,
ParallelType::Vectorize,
ParallelType::MisalignedVectorize}));

inlineMost();

// CUDA kernel is equivalent, but automatic scheduling uses i23 instead of
// i22 for the name of the index variable in the loop (rFactor, see tv3)
compare_ir(fusion, fauto);

// Perform eager computation and verify
@@ -271,7 +339,30 @@ TEST_F(NVFuserTest, FusionFrontendSuperBasicFP16_CUDA) {

std::vector<IValue> inputs = {t0};

// Define fusion
Fusion fauto;
{ // Do automatic scheduling on fauto
FusionGuard fg(&fauto);

auto tv0 = makeSymbolicTensor(2, DataType::Half); // {i0, i1}
auto c0 = IrBuilder::create<Double>(3.0);

fauto.addInput(tv0);

auto tv1 = mul(tv0, c0); // {i0, i1}
auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1}
auto tv3 = castOp(DataType::Half, tv2);

fauto.addOutput(tv3);

// Run automatic scheduler
auto reduction_params = getReductionHeuristics(&fauto, inputs);
TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
scheduleReduction(&fauto, *reduction_params);
}

// Re-define the fusion exactly for manual scheduling.
// This is necessary in order to capture all the constructor calls inside
// each Fusion independently.
Fusion fusion;
FusionGuard fg(&fusion);

@@ -280,50 +371,46 @@ TEST_F(NVFuserTest, FusionFrontendSuperBasicFP16_CUDA) {

fusion.addInput(tv0);

// Note: A manual schedule will run without an explicit cast here, producing
// an _implicit_ cast to float. That float tensor will not be explicitly
// parallelized to match, but will result
auto tv0float = castOp(DataType::Float, tv0);

auto tv1 = mul(tv0float, c0); // {i0, i1}
auto tv2float = sum(tv1, {-1}, false, DataType::Float); // {i0, r1}
auto tv2 = castOp(DataType::Half, tv2float);

fusion.addOutput(tv2);
auto tv1 = mul(tv0, c0); // {i0, i1}
auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1}
auto tv4 = castOp(DataType::Half, tv2);

// Run automatic scheduler
auto fauto = Fusion(fusion); // unique_ptr to copy of fusion
auto reduction_params = getReductionHeuristics(&fauto, inputs);
TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
scheduleReduction(&fauto, *reduction_params);
fusion.addOutput(tv4);

// Perform manual scheduling
tv2float->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv2float->split(1, 1);
tv2float->reorder({{-1, -2}, {-2, -1}});
tv2->reorder({{1, 0}}); // Removing these two reorders does not affect the
// generated kernel
tv2->reorder({{1, 0}});
tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv2->axis(2)->parallelize(ParallelType::TIDx);
tv2->split(1, 1);
tv2->axis(2)->parallelize(ParallelType::Unswitch);
tv2->axis(0)->parallelize(ParallelType::BIDx);

auto tv3 = tv2float->rFactor({1, 3});
// tv2->reorder({{-2, -1}}) has the same effect, but this shows the mapping
// explicitly
tv2->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 2}});

tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(2)->parallelize(ParallelType::TIDx);
tv3->axis(3)->parallelize(ParallelType::Unswitch);
auto tv3 = tv2->rFactor({1, 3});

// propagate the mapping to other tensors
TransformPropagatorWithCheck propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv3, {tv0, tv0float, tv1, tv2, tv2float});

tv1->computeAt(tv3, -1);
scheduler_utils::parallelizeAllLike(
tv3,
{},
allParallelTypesExcept(
{ParallelType::Unroll,
ParallelType::Vectorize,
ParallelType::MisalignedVectorize}));

inlineMost();

// CUDA kernel is equivalent, but automatic scheduling uses i31 instead of
// i30 for the name of the index variable in the loop (rFactor, see tv3)
compare_ir(fusion, fauto);

// Perform eager computation and verify
auto t1 = t0 * 3.0;
auto t2 = t1.sum({-1}, false, c10::kFloat);
auto t2 = t1.sum({-1}, false);

int runtime_threadIdx_dim = 128;
LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
@@ -384,9 +471,6 @@ TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
scheduleReduction(&fauto, *reduction_params);

// Perform manual scheduling
//
// {i0, i1, i2}

tv4->merge(0, 1); // {i0*i1, i2}
tv4->split(
1,
@@ -403,12 +487,16 @@ TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
tv4->axis(3)->parallelize(ParallelType::Unroll);
tv4->axis(4)->parallelize(ParallelType::TIDx);
tv4->axis(5)->parallelize(ParallelType::Unswitch);
auto tv5 = tv4->rFactor({1, 5});

auto tv5 = tv0->cacheAfter();
auto tv6 = tv1->cacheAfter();
auto tv7 = tv4->cacheBefore();
auto tv8 = tv7->rFactor({1, 5});

// propagate the mapping to other tensors
TransformPropagatorWithCheck propagator(tv5);
MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv5, {tv0, tv1, tv2, tv3, tv4});
TransformPropagatorWithCheck propagator(tv7);
MaxRootDomainInfoSpanningTree(tv7).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv7, {tv2, tv3, tv4, tv5, tv6, tv8});

inlineMost();

