diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
index 8e7ec987cbca0..e6b69f9c02387 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -311,8 +311,12 @@ void scheduleMatmul(
   // Setup accumulator register.
   auto cc = c->cacheBefore();
 
-  // Setup output smem buffer
-  auto c_smem = c->cacheBefore();
+  TensorView* c_smem = nullptr;
+
+  if (params.has_epilog) {
+    // Setup output smem buffer
+    c_smem = c->cacheBefore();
+  }
 
   // Get the input to the mma op.
   auto mma = dynamic_cast<MmaOp*>(cc->definition());
@@ -483,7 +487,10 @@ void scheduleMatmul(
   // Set memory type:
   acw_smem->setMemoryType(MemoryType::Shared);
   bcw_smem->setMemoryType(MemoryType::Shared);
-  c_smem->setMemoryType(MemoryType::Shared);
+
+  if (params.has_epilog) {
+    c_smem->setMemoryType(MemoryType::Shared);
+  }
 
   // Set parallelization:
   //  TODO: this section goes to a separate matmul util,
@@ -534,37 +541,44 @@ void scheduleMatmul(
     bcr->skewDoubleBuffer();
   }
 
+  auto output_buffer = params.has_epilog ? c_smem : c;
+
   scheduler_utils::BoundedDirectionalTransformPropagator::forward(
       cc,
       -1,
-      {c_smem},
-      scheduler_utils::BoundedDirectionalTransformPropagator::Options()
-          .propagateParallelType()
-          .propagateToBoundary());
-
-  // Epilog schedule:
-  scheduler_utils::BoundedDirectionalTransformPropagator::forward(
-      c_smem,
-      3,
-      {c},
+      {output_buffer},
       scheduler_utils::BoundedDirectionalTransformPropagator::Options()
           .propagateParallelType()
          .propagateToBoundary());
 
-  c_smem->computeAt(c, 3);
-  c->reorder({{-1, -2}, {-2, -1}});
-  // 16 x 128, with half of the warps:
-
-  // Output vectorize by 4:
-  c->split(-2, 2);
-  c->split(-1, 4);
-
-  // [8, 2, 32, 4]
-  c->axis(-3)->parallelize(ParallelType::TIDy);
-  c->axis(-2)->parallelize(ParallelType::TIDx);
-  c->axis(-1)->parallelize(ParallelType::Vectorize);
-  c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
-  c_smem->doubleBuffer();
+  // Epilog schedule (to be built out):
+  if (params.has_epilog) {
+    scheduler_utils::BoundedDirectionalTransformPropagator::forward(
+        c_smem,
+        3,
+        {c},
+        scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+            .propagateParallelType()
+            .propagateToBoundary());
+
+    c_smem->computeAt(c, 3);
+    c->reorder({{-1, -2}, {-2, -1}});
+    // 16 x 128, with half of the warps:
+
+    // Output vectorize by 4:
+    c->split(-2, 2);
+    c->split(-1, 4);
+
+    // [8, 2, 32, 4]
+    c->axis(-3)->parallelize(ParallelType::TIDy);
+    c->axis(-2)->parallelize(ParallelType::TIDx);
+    c->axis(-1)->parallelize(ParallelType::Vectorize);
+    c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
+    c_smem->doubleBuffer();
+  } else {
+    // Always vectorize the output:
+    c->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
 
   if (params.index_lift_options.lift_gmem_read_address) {
     a->liftReadAddress();
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h
index 354e2affeab04..e105b2e7e02cf 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h
@@ -51,6 +51,9 @@ class MatmulParam {
 
   //! Enables predicate peeling mainloop:
   bool peel_main_loop = true;
+
+  //! Enables an epilog schedule
+  bool has_epilog = false;
 };
 
 //! Prototype auto scheduling function.