diff --git a/third_party/nvfuser/csrc/codegen.cpp b/third_party/nvfuser/csrc/codegen.cpp
index f5a789576cff5..99c8eb924bfd6 100644
--- a/third_party/nvfuser/csrc/codegen.cpp
+++ b/third_party/nvfuser/csrc/codegen.cpp
@@ -526,12 +526,12 @@ class CudaKernelGenerator : private OptOutConstDispatch {
       // Out of line predicate variant
       code_ << "<" << dtype << ", " << vec_size << ">("
             << genInline(ldst->out()->as()->index()) << ","
-            << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
+            << genInline(ldst->in()->as()->index()) << ");\n";
     } else {
       // Inline predicate variant
       code_ << "<" << dtype << ", " << vec_size << ">("
             << genInline(ldst->out()->as()->index()) << ","
-            << genVectorPointer(ldst->in(), dtype, vec_size) << ","
+            << genInline(ldst->in()->as()->index()) << ","
             << genInline(ldst->predicate()) << ");\n";
     }
   }
@@ -548,11 +548,20 @@ class CudaKernelGenerator : private OptOutConstDispatch {
        << ");\n";
   }

-  void handle(const kir::SMemAddress* sop) final {
+  void handle(const kir::BaseAddress* sop) final {
     if (!print_inline_) {
       indent() << gen(sop->output(0)) << " = ";
     }
-    code_ << "toSmem(" << ir_utils::varName(sop->smemTv()) << ")";
+    switch (sop->tv()->getMemoryType()) {
+      case MemoryType::Shared:
+        code_ << "toSmem(" << ir_utils::varName(sop->tv()) << ")";
+        break;
+      case MemoryType::Global:
+        code_ << ir_utils::varName(sop->tv()) << ".data";
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "Unsupported input for kir::BaseAddress");
+    }
     if (!print_inline_) {
       code_ << ";\n";
     }
diff --git a/third_party/nvfuser/csrc/dispatch.cpp b/third_party/nvfuser/csrc/dispatch.cpp
index eb5a1653ad369..74548048b1fba 100644
--- a/third_party/nvfuser/csrc/dispatch.cpp
+++ b/third_party/nvfuser/csrc/dispatch.cpp
@@ -41,6 +41,10 @@ template
 void Val::dispatch(T handler, Val* val) {
   switch (*(val->getValType())) {
     case ValType::Scalar:
+      if (std::holds_alternative(val->getDataType()->type)) {
+        ptr(handler)->handle(val->as());
+        return;
+      }
       switch (std::get(val->getDataType()->type)) {
         case DataType::Bool:
           ptr(handler)->handle(val->as());
@@ -273,8 +277,8 @@ void Expr::dispatch(T handler, Expr* expr) {
     ptr(handler)->handle(expr->as());
     return;
   }
-  if (expr->isStrictlyA<kir::SMemAddress>()) {
-    ptr(handler)->handle(expr->as<kir::SMemAddress>());
+  if (expr->isStrictlyA<kir::BaseAddress>()) {
+    ptr(handler)->handle(expr->as<kir::BaseAddress>());
     return;
   }
   TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -294,6 +298,10 @@ template
 void Val::constDispatch(T handler, const Val* val) {
   switch (*(val->getValType())) {
     case ValType::Scalar:
+      if (std::holds_alternative(val->getDataType()->type)) {
+        ptr(handler)->handle(val->as());
+        return;
+      }
       switch (std::get(val->getDataType()->type)) {
         case DataType::Bool:
           ptr(handler)->handle(val->as());
@@ -530,8 +538,8 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     ptr(handler)->handle(expr->as());
     return;
   }
-  if (expr->isStrictlyA<kir::SMemAddress>()) {
-    ptr(handler)->handle(expr->as<kir::SMemAddress>());
+  if (expr->isStrictlyA<kir::BaseAddress>()) {
+    ptr(handler)->handle(expr->as<kir::BaseAddress>());
     return;
   }
   TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -562,6 +570,10 @@ template
 void Val::mutatorDispatch(T mutator, Val* val) {
   switch (*(val->getValType())) {
     case ValType::Scalar:
+      if (std::holds_alternative(val->getDataType()->type)) {
+        ptr(mutator)->mutate(val->as());
+        return;
+      }
       switch (std::get(val->getDataType()->type)) {
         case DataType::Bool:
           ptr(mutator)->mutate(val->as());
@@ -888,7 +900,7 @@ void OptOutConstDispatch::handle(const kir::VectorizedWelfordOp* stmt) {
 void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
   unhandled(stmt);
 }

-void OptOutConstDispatch::handle(const kir::SMemAddress* stmt) {
+void OptOutConstDispatch::handle(const kir::BaseAddress* stmt) {
   unhandled(stmt);
 }
@@ -1062,7 +1074,7 @@ void OptOutDispatch::handle(kir::VectorizedWelfordOp* stmt) {
 void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
   unhandled(stmt);
 }

-void OptOutDispatch::handle(kir::SMemAddress* stmt) {
+void OptOutDispatch::handle(kir::BaseAddress* stmt) {
   unhandled(stmt);
 }
diff --git a/third_party/nvfuser/csrc/dispatch.h b/third_party/nvfuser/csrc/dispatch.h
index 55cc79b8563ff..c7eb4ba81536e 100644
--- a/third_party/nvfuser/csrc/dispatch.h
+++ b/third_party/nvfuser/csrc/dispatch.h
@@ -119,7 +119,7 @@ class VectorizedWelfordOp;
 class AllocateFusedReduction;
 class InitMagicZero;
 class UpdateMagicZero;
-class SMemAddress;
+class BaseAddress;

 } // namespace kir

@@ -195,7 +195,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const kir::GroupedGridWelford*);
   virtual void handle(const kir::VectorizedWelfordOp*);
   virtual void handle(const kir::AllocateFusedReduction*);
-  virtual void handle(const kir::SMemAddress*);
+  virtual void handle(const kir::BaseAddress*);
 };

 class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
@@ -268,7 +268,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(kir::GroupedGridWelford* stmt);
   virtual void handle(kir::VectorizedWelfordOp* stmt);
   virtual void handle(kir::AllocateFusedReduction* stmt);
-  virtual void handle(kir::SMemAddress* stmt);
+  virtual void handle(kir::BaseAddress* stmt);
 };

 class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch {
diff --git a/third_party/nvfuser/csrc/expr_evaluator.cpp b/third_party/nvfuser/csrc/expr_evaluator.cpp
index 21dc74ed168c6..300fd2c38e422 100644
--- a/third_party/nvfuser/csrc/expr_evaluator.cpp
+++ b/third_party/nvfuser/csrc/expr_evaluator.cpp
@@ -104,7 +104,7 @@ c10::optional ExpressionEvaluator::evaluate(const Val* value) {
   if (!maybe_concrete_value.has_value()) {
     if (auto def = value->definition()) {
       FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
-      if (def->isA<kir::SMemAddress>()) {
+      if (def->isA<kir::BaseAddress>()) {
         return c10::nullopt;
       }
       std::vector inputs;
diff --git a/third_party/nvfuser/csrc/index_compute.cpp b/third_party/nvfuser/csrc/index_compute.cpp
index 0f0ccbf57429c..c8f2c2bce6716 100644
--- a/third_party/nvfuser/csrc/index_compute.cpp
+++ b/third_party/nvfuser/csrc/index_compute.cpp
@@ -22,6 +22,8 @@
 #include
 #include

+#include
+
 namespace nvfuser {

 namespace {
@@ -312,6 +314,23 @@ Val* getProducerIndexWithPartialSplit(
       producer_index, SimplifyingIrBuilder::create(diff->evaluateInt()));
 }

+Val* getTensorBaseAddress(TensorView* tv) {
+  Val* output = nullptr;
+  switch (auto memtype = tv->getMemoryType()) {
+    case MemoryType::Global:
+      output = IrBuilder::newScalar(
+          PointerOf{std::make_shared(*tv->getDataType())});
+      break;
+    case MemoryType::Shared:
+      output = IrBuilder::newScalar(DataType::SMemAddress);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported memory type ", memtype);
+  }
+  IrBuilder::create(output, tv);
+  return output;
+}
+
 } // namespace

 void IndexCompute::handle(Split* split) {
@@ -2135,26 +2154,34 @@ Val* Index::getProducerStridedIndices(
     const std::vector& loops,
     const std::unordered_set& rotated_loops,
     const std::unordered_map& override_index,
-    bool cvta_smem_address) {
+    bool generate_pointer) {
   FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
   if (producer->domain()->noReductions().size() == 0) {
-    return GpuLower::current()->kernel()->zeroVal();
+    if (generate_pointer) {
+      return getTensorBaseAddress(producer);
+    } else {
+      return GpuLower::current()->kernel()->zeroVal();
+    }
   }

   if (producer->getMemoryType() == MemoryType::Global) {
-    return sumVals(getGlobalProducerStridedIndices(
+    auto index = sumVals(getGlobalProducerStridedIndices(
         producer, consumer, loops, rotated_loops, override_index));
+    if (generate_pointer) {
+      return SimplifyingIrBuilder::addExpr(
+          getTensorBaseAddress(producer), index);
+    } else {
+      return index;
+    }
   } else {
     auto index = sumVals(getNonGlobalProducerStridedIndices(
         producer, consumer, loops, rotated_loops, override_index));
-    if (cvta_smem_address && producer->getMemoryType() == MemoryType::Shared) {
-      auto base_address = IrBuilder::newScalar(DataType::SMemAddress);
-      IrBuilder::create(base_address, producer);
+    if (generate_pointer) {
       auto index_bytes = IrBuilder::mulExpr(
           index,
           IrBuilder::newConstant(
               dataTypeSize(*producer->getDataType()), *index->getDataType()));
-      return IrBuilder::addExpr(base_address, index_bytes);
+      return IrBuilder::addExpr(getTensorBaseAddress(producer), index_bytes);
     } else {
       return index;
     }
@@ -2168,14 +2195,14 @@ kir::TensorIndex* Index::getProducerIndex(
     const std::vector& loops,
     const std::unordered_set& rotated_loops,
     const std::unordered_map& override_index,
-    bool cvta_smem_address) {
+    bool generate_pointer) {
   auto index = getProducerStridedIndices(
       producer,
       consumer,
       loops,
       rotated_loops,
       override_index,
-      cvta_smem_address);
+      generate_pointer);
   index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);
   return SimplifyingIrBuilder::create(producer, index);
 }
@@ -2185,26 +2212,34 @@ Val* Index::getConsumerStridedIndices(
     const std::vector& loops,
     const std::unordered_set& rotated_loops,
     const std::unordered_map& override_index,
-    bool cvta_smem_address) {
+    bool generate_pointer) {
   FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
   if (consumer->domain()->noReductions().size() == 0) {
-    return GpuLower::current()->kernel()->zeroVal();
+    if (generate_pointer) {
+      return getTensorBaseAddress(consumer);
+    } else {
+      return GpuLower::current()->kernel()->zeroVal();
+    }
   }

   if (consumer->getMemoryType() == MemoryType::Global) {
-    return sumVals(getGlobalConsumerStridedIndices(
+    auto index = sumVals(getGlobalConsumerStridedIndices(
         consumer, loops, rotated_loops, override_index));
+    if (generate_pointer) {
+      return SimplifyingIrBuilder::addExpr(
+          getTensorBaseAddress(consumer), index);
+    } else {
+      return index;
+    }
   } else {
     auto index = sumVals(
         getNonGlobalConsumerStridedIndices(consumer, loops, rotated_loops));
-    if (cvta_smem_address && consumer->getMemoryType() == MemoryType::Shared) {
-      auto base_address = IrBuilder::newScalar(DataType::SMemAddress);
-      IrBuilder::create(base_address, consumer);
+    if (generate_pointer) {
       auto index_bytes = IrBuilder::mulExpr(
           index,
           IrBuilder::newConstant(
               dataTypeSize(*consumer->getDataType()), *index->getDataType()));
-      return IrBuilder::addExpr(base_address, index_bytes);
+      return IrBuilder::addExpr(getTensorBaseAddress(consumer), index_bytes);
     } else {
       return index;
     }
@@ -2217,9 +2252,9 @@ kir::TensorIndex* Index::getConsumerIndex(
     const std::vector& loops,
     const std::unordered_set& rotated_loops,
     const std::unordered_map& override_index,
-    bool cvta_smem_address) {
+    bool generate_pointer) {
   auto index = getConsumerStridedIndices(
-      consumer, loops, rotated_loops, override_index, cvta_smem_address);
+      consumer, loops, rotated_loops, override_index, generate_pointer);
   index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);
   return SimplifyingIrBuilder::create(consumer, index);
 }
diff --git a/third_party/nvfuser/csrc/index_compute.h b/third_party/nvfuser/csrc/index_compute.h
index bb2b528b38f79..8f8efeef2ef67 100644
--- a/third_party/nvfuser/csrc/index_compute.h
+++ b/third_party/nvfuser/csrc/index_compute.h
@@ -346,18 +346,20 @@ class Index {
   // Consumer = Producer
   // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
   // Producer indexing dispatch
-  // The argument `cvta_smem_address` specifies whether to use `cvta` ptx to
-  // convert shared memory address to unsigned int for indexing. This argument
-  // is effective only if the indexed tensor is a shared memory tensor. On other
-  // memory type, this argument will be ignored. Search `toSmem` in the codebase
-  // for additional information.
+  // The argument `generate_pointer` specifies whether to generate a pointer
+  // for the tensor. For a global memory tensor, this generates `T1.data`. For
+  // a shared memory tensor, it uses the `cvta` PTX instruction to convert the
+  // shared memory address to an unsigned integer for indexing. Search `toSmem`
+  // in the codebase for additional information. This argument only applies to
+  // shared memory and global tensors; requesting a pointer for any other
+  // memory type is an error.
   static kir::TensorIndex* getProducerIndex(
       TensorView* producer,
       const TensorView* consumer,
       const std::vector& loops,
       const std::unordered_set& rotated_loops,
       const std::unordered_map& override_index = {},
-      bool cvta_smem_address = false);
+      bool generate_pointer = false);

   // Consumer index dispatch
   static kir::TensorIndex* getConsumerIndex(
@@ -365,7 +367,7 @@ class Index {
       const std::vector& loops,
       const std::unordered_set& rotated_loops,
       const std::unordered_map& override_index = {},
-      bool cvta_smem_address = false);
+      bool generate_pointer = false);

   //! Returns a vector of strided indices mapped onto the (rfactor)
   //! root domain of a producer tensor. The size of the returned
@@ -377,7 +379,7 @@ class Index {
       const std::vector& loops,
       const std::unordered_set& rotated_loops,
       const std::unordered_map& override_index = {},
-      bool cvta_smem_address = false);
+      bool generate_pointer = false);

   //! Returns a vector of strided indices mapped onto the (rfactor)
   //! root domain of a consumer tensor. The size of the returned
@@ -388,7 +390,7 @@ class Index {
       const std::vector& loops,
       const std::unordered_set& rotated_loops,
       const std::unordered_map& override_index = {},
-      bool cvta_smem_address = false);
+      bool generate_pointer = false);

   //! Returns the logical index linearized from a multi-dimension address into a
 //! linear memory address a consumer tensor. The returned index is intended to
diff --git a/third_party/nvfuser/csrc/ir_base_nodes.cpp b/third_party/nvfuser/csrc/ir_base_nodes.cpp
index e2a1e1acc2961..f6729dc813e63 100644
--- a/third_party/nvfuser/csrc/ir_base_nodes.cpp
+++ b/third_party/nvfuser/csrc/ir_base_nodes.cpp
@@ -259,7 +259,8 @@ int64_t Val::evaluateInt() {
   auto evaluated_val = ee.evaluate(this);
   TORCH_INTERNAL_ASSERT(
       evaluated_val.has_value(),
-      "Detected a const integer but failed to infer its value.");
+      "Detected a const integer but failed to infer its value: ",
+      toInlineString());
   return evaluated_val->as();
 }
diff --git a/third_party/nvfuser/csrc/ir_builder.cpp b/third_party/nvfuser/csrc/ir_builder.cpp
index 4777b7aba93e4..f0df70210c642 100644
--- a/third_party/nvfuser/csrc/ir_builder.cpp
+++ b/third_party/nvfuser/csrc/ir_builder.cpp
@@ -11,6 +11,9 @@ namespace nvfuser {

 Val* IrBuilder::newScalar(DataType dtype) {
+  if (isPointerType(dtype)) {
+    return IrBuilder::create(dtype);
+  }
   switch (std::get(dtype.type)) {
     case DataType::Bool:
       return IrBuilder::create();
@@ -20,7 +23,6 @@ Val* IrBuilder::newScalar(DataType dtype) {
     case DataType::Int:
     case DataType::Int32:
     case DataType::Index:
-    case DataType::SMemAddress:
       return IrBuilder::create(dtype);
     case DataType::ComplexFloat:
     case DataType::ComplexDouble:
@@ -49,11 +51,16 @@ Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
   // than just allowing the integer type promotion for the two inputs as below.
   // Note that this is only needed for integer types. See also PR #2228.
   if (lhs->dtype() != rhs->dtype()) {
-    if (lhs->dtype() == DataType::SMemAddress ||
-        rhs->dtype() == DataType::SMemAddress) {
+    if (isPointerType(lhs->dtype())) {
+      TORCH_INTERNAL_ASSERT(isIntegralType(rhs->dtype()));
+      TORCH_INTERNAL_ASSERT(
+          op_type == BinaryOpType::Add || op_type == BinaryOpType::Sub);
+      dtype = lhs->dtype();
+    } else if (isPointerType(rhs->dtype())) {
+      TORCH_INTERNAL_ASSERT(isIntegralType(lhs->dtype()));
       TORCH_INTERNAL_ASSERT(
           op_type == BinaryOpType::Add || op_type == BinaryOpType::Sub);
-      dtype = DataType::SMemAddress;
+      dtype = rhs->dtype();
     } else if (
         (lhs->dtype() == DataType::Int && rhs->dtype() == DataType::Int32) ||
         (lhs->dtype() == DataType::Int32 && rhs->dtype() == DataType::Int)) {
diff --git a/third_party/nvfuser/csrc/ir_nodes.cpp b/third_party/nvfuser/csrc/ir_nodes.cpp
index 3d88a24fd1643..f8b13c2b23ea8 100644
--- a/third_party/nvfuser/csrc/ir_nodes.cpp
+++ b/third_party/nvfuser/csrc/ir_nodes.cpp
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include

 namespace nvfuser {
diff --git a/third_party/nvfuser/csrc/kernel_ir.cpp b/third_party/nvfuser/csrc/kernel_ir.cpp
index 63a7f8ef90dcf..cba8f10b4bf52 100644
--- a/third_party/nvfuser/csrc/kernel_ir.cpp
+++ b/third_party/nvfuser/csrc/kernel_ir.cpp
@@ -328,31 +328,28 @@ std::string UpdateMagicZero::toInlineString(int indent_size) const {

 NVFUSER_DEFINE_CLONE_AND_CREATE(UpdateMagicZero)

-SMemAddress::SMemAddress(
-    IrBuilderPasskey passkey,
-    Val* out,
-    TensorView* smem_tv)
+BaseAddress::BaseAddress(IrBuilderPasskey passkey, Val* out, TensorView* tv)
     : Expr(passkey) {
   TORCH_INTERNAL_ASSERT(
       passkey.ir_container_->isA(),
       "IR type only valid for Kernel container.");
   addOutput(out);
-  addInput(smem_tv);
+  addInput(tv);
 }

-std::string SMemAddress::toString(int indent_size) const {
+std::string BaseAddress::toString(int indent_size) const {
   std::stringstream ss;
-  indent(ss, indent_size) << "toSmem(" << ir_utils::varName(smemTv()) << ")\n";
+  indent(ss, indent_size) << "BaseAddress(" << ir_utils::varName(tv()) << ")\n";
   return ss.str();
 }

-std::string SMemAddress::toInlineString(int indent_size) const {
+std::string BaseAddress::toInlineString(int indent_size) const {
   std::stringstream ss;
-  ss << "toSmem(" << ir_utils::varName(smemTv()) << ")";
+  ss << "BaseAddress(" << ir_utils::varName(tv()) << ")";
   return ss.str();
 }

-NVFUSER_DEFINE_CLONE_AND_CREATE(SMemAddress)
+NVFUSER_DEFINE_CLONE_AND_CREATE(BaseAddress)

 std::string Scope::toString(int indent_size) const {
   std::stringstream ss;
diff --git a/third_party/nvfuser/csrc/kernel_ir.h b/third_party/nvfuser/csrc/kernel_ir.h
index d05ba72d603ac..a7c41f1e12fc1 100644
--- a/third_party/nvfuser/csrc/kernel_ir.h
+++ b/third_party/nvfuser/csrc/kernel_ir.h
@@ -360,24 +360,24 @@ class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
   std::string toInlineString(int indent_size = 0) const override;
 };

-class TORCH_CUDA_CU_API SMemAddress final : public Expr {
+class TORCH_CUDA_CU_API BaseAddress final : public Expr {
  public:
   using Expr::Expr;

-  explicit SMemAddress(IrBuilderPasskey passkey, Val* out, TensorView* smem_tv);
+  explicit BaseAddress(IrBuilderPasskey passkey, Val* out, TensorView* tv);

   NVFUSER_DECLARE_CLONE_AND_CREATE

   virtual const char* getOpString() const override {
-    return "SMemAddress";
+    return "BaseAddress";
   }

-  TensorView* smemTv() const {
+  TensorView* tv() const {
     return input(0)->as();
   }

   bool sameAs(const Statement* other) const override {
-    auto other_saddr = dynamic_cast<const SMemAddress*>(other);
+    auto other_saddr = dynamic_cast<const BaseAddress*>(other);
     if (other_saddr == nullptr) {
       return false;
     }
@@ -386,7 +386,7 @@ class TORCH_CUDA_CU_API SMemAddress final : public Expr {
     // T1_s = set(T0)
     // T2_s = set(T0)
     // Then T1_s and T2_s has different address although T1_s->sameAs(T2_s)
-    return other_saddr->smemTv() == smemTv();
+    return other_saddr->tv() == tv();
   }

   std::string toString(int indent_size = 0) const override;
diff --git a/third_party/nvfuser/csrc/lower_index.cpp b/third_party/nvfuser/csrc/lower_index.cpp
index f11580e7f3de3..608ff6841b8d8 100644
--- a/third_party/nvfuser/csrc/lower_index.cpp
+++ b/third_party/nvfuser/csrc/lower_index.cpp
@@ -14,7 +14,7 @@ Val* IndexLowering::lowerSrcIndex(
     Val* src,
     Val* dst,
     const std::unordered_map& override_index,
-    bool cvta_smem_address) const {
+    bool generate_pointer) const {
   if (auto tv = dynamic_cast(src)) {
     TORCH_INTERNAL_ASSERT(dst->isA());
     return Index::getProducerIndex(
@@ -23,7 +23,7 @@ Val* IndexLowering::lowerSrcIndex(
         for_loops_,
         getRotatedLoop(),
         override_index,
-        cvta_smem_address);
+        generate_pointer);
   } else {
     return src;
   }
@@ -32,10 +32,10 @@ Val* IndexLowering::lowerSrcIndex(
 Val* IndexLowering::lowerDstIndex(
     Val* dst,
     const std::unordered_map& override_index,
-    bool cvta_smem_address) const {
+    bool generate_pointer) const {
   if (auto tv = dynamic_cast(dst)) {
     return Index::getConsumerIndex(
-        tv, for_loops_, getRotatedLoop(), override_index, cvta_smem_address);
+        tv, for_loops_, getRotatedLoop(), override_index, generate_pointer);
   } else {
     return dst;
   }
@@ -1215,8 +1215,11 @@ void IndexLowering::handleGroupedGridWelford(
 }

 void IndexLowering::handle(const LoadStoreOp* ldst) {
+  // Today, LoadStoreOp can only be ld.matrix or cp.async. In the future, when
+  // we start to work on Hopper support, this can also be a TMA operation.
   const auto in = lowerSrcIndex(ldst->in(), ldst->out(), {}, true);
-  const auto out = lowerDstIndex(ldst->out(), {}, true);
+  const auto out =
+      lowerDstIndex(ldst->out(), {}, !ir_utils::isLdMatrixOp(ldst));
   auto new_ldst = IrBuilder::create(ldst->opType(), out, in)
                       ->withPredicate(ldst->predicate());
   pushBack(new_ldst);
diff --git a/third_party/nvfuser/csrc/lower_index.h b/third_party/nvfuser/csrc/lower_index.h
index eeddbcc22d33c..d0293a3842852 100644
--- a/third_party/nvfuser/csrc/lower_index.h
+++ b/third_party/nvfuser/csrc/lower_index.h
@@ -77,21 +77,23 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
   // This is can used to manually set the index for the given rFactor ID.
   // Currently, this `override_index` is only used by indexing ops like
   // select/index_select.
-  // The argument `cvta_smem_address` specifies whether to use `cvta` ptx to
-  // convert shared memory address to unsigned int for indexing. This argument
-  // is effective only if the indexed tensor is a shared memory tensor. On other
-  // memory type, this argument will be ignored. Search `toSmem` in the codebase
-  // for additional information.
+  // The argument `generate_pointer` specifies whether to generate a pointer
+  // for the tensor. For a global memory tensor, this generates `T1.data`. For
+  // a shared memory tensor, it uses the `cvta` PTX instruction to convert the
+  // shared memory address to an unsigned integer for indexing. Search `toSmem`
+  // in the codebase for additional information. This argument only applies to
+  // shared memory and global tensors; requesting a pointer for any other
+  // memory type is an error.
   Val* lowerSrcIndex(
       Val* val,
       Val* dst,
       const std::unordered_map& override_index = {},
-      bool cvta_smem_address = false) const;
+      bool generate_pointer = false) const;

   Val* lowerDstIndex(
       Val* dst,
       const std::unordered_map& override_index = {},
-      bool cvta_smem_address = false) const;
+      bool generate_pointer = false) const;

   void handleBlockReduction(const ReductionOp* rop, Val* out, Val* in);
   void handleGridReduction(const ReductionOp* rop, Val* out, Val* in);
diff --git a/third_party/nvfuser/csrc/lower_scalar_hoist.cpp b/third_party/nvfuser/csrc/lower_scalar_hoist.cpp
index 2e33d4e9e1e8d..121a9ca2577ee 100644
--- a/third_party/nvfuser/csrc/lower_scalar_hoist.cpp
+++ b/third_party/nvfuser/csrc/lower_scalar_hoist.cpp
@@ -402,7 +402,7 @@ class CommonIndexInserter : private kir::ExprMutator {
       // but this seems to be the quickest way to use the value type
       // as we don't have a scalar IR node for the value type.
       auto dtype = *value->getDataType();
-      if (isIntegralType(dtype) && dtype != DataType::SMemAddress) {
+      if (isIntegralType(dtype) && !isPointerType(dtype)) {
        value->resolveIndexDtype();
       }
diff --git a/third_party/nvfuser/csrc/type.cpp b/third_party/nvfuser/csrc/type.cpp
index 89ad5937bc1ab..b8592ca7f184e 100644
--- a/third_party/nvfuser/csrc/type.cpp
+++ b/third_party/nvfuser/csrc/type.cpp
@@ -59,6 +59,11 @@ bool isIntegralType(DataType dtype) {
       dtype.type);
 }

+bool isPointerType(DataType dtype) {
+  return std::holds_alternative(dtype.type) ||
+      dtype == DataType::SMemAddress;
+}
+
 bool isComplexType(DataType dtype) {
   TORCH_CHECK(
       dtype != DataType::Null,
@@ -1063,6 +1068,9 @@ std::string stringifyThread(const ParallelType ptype) {
 }

 std::string typePrefix(const DataType data_type) {
+  if (std::holds_alternative(data_type.type)) {
+    return "ptr";
+  }
   switch (std::get(data_type.type)) {
     case DataType::Bool:
       return "b";
diff --git a/third_party/nvfuser/csrc/type.h b/third_party/nvfuser/csrc/type.h
index a70560a98c2d7..9f2e3efea373d 100644
--- a/third_party/nvfuser/csrc/type.h
+++ b/third_party/nvfuser/csrc/type.h
@@ -138,9 +138,11 @@ DataType indexModeToDtype(KernelIndexMode index_mode);
 // Returns if the datatype is a floating point type
 TORCH_CUDA_CU_API bool isFloatingPointType(DataType dtype);
-// Returns if the datatype is an boolean type
-TORCH_CUDA_CU_API bool isIntegralType(DataType dtype);
 // Returns if the datatype is an integer type
+TORCH_CUDA_CU_API bool isIntegralType(DataType dtype);
+// Returns if the datatype is a pointer type
+TORCH_CUDA_CU_API bool isPointerType(DataType dtype);
+// Returns if the datatype is a boolean type
 TORCH_CUDA_CU_API bool isBooleanType(DataType dtype);
 // Returns if the datatype is a complex type
 TORCH_CUDA_CU_API bool isComplexType(DataType dtype);
diff --git a/third_party/nvfuser/test/test_gpu_loop_rotation.cpp b/third_party/nvfuser/test/test_gpu_loop_rotation.cpp
index 81194936c6c7a..2fa1c5a0a6b99 100644
--- a/third_party/nvfuser/test/test_gpu_loop_rotation.cpp
+++ b/third_party/nvfuser/test/test_gpu_loop_rotation.cpp
@@ -605,8 +605,8 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
 }

 // This is a case similar to matmul, where we have
-// tv1 = set(tv0) // cp.async for matmul
-// tv2 = set(tv1) // ld.matrix for matmul
+// tv4 = set(tv0) // cp.async for matmul
+// tv1 = set(tv4) // ld.matrix for matmul
 // and both are double buffered
 TEST_F(NVFuserTest, FusionLoopRotationMultipleDoubleBuffer_CUDA) {
   // Please see note [Limitation of boundary assert]
@@ -620,110 +620,101 @@ TEST_F(NVFuserTest, FusionLoopRotationMultipleDoubleBuffer_CUDA) {
   auto tv1 = set(tv0);
   auto tv2 = set(tv1);
   auto tv3 = set(tv2);
-  auto tv4 = set(tv3);
-  fusion.addOutput(tv4);
+  fusion.addOutput(tv3);

-  tv1->setMemoryType(MemoryType::Shared);
+  auto tv4 = tv0->cacheAfter(LoadStoreOpType::CpAsyncCa);
+  tv4->setMemoryType(MemoryType::Shared);

-  inlineAllAt(tv4, 1);
-  inlineSelectedAt({tv2, tv3, tv4}, tv4, 2);
+  inlineAllAt(tv3, 1);
+  inlineSelectedAt({tv1, tv2, tv3}, tv3, 2);

-  tv1->circularBuffer(5);
-  tv2->doubleBuffer();
-  scheduler_utils::rotateLoop(tv4, 0, {tv2});
+  tv4->circularBuffer(5);
+  tv1->doubleBuffer();
+  scheduler_utils::rotateLoop(tv3, 0, {tv1});

   const std::string expected_kernel = R"(
-__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
+__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T3) {
   alignas(16) extern __shared__ char array[];
   unsigned smem_offset = 0;
   NVFUSER_DEFINE_MAGIC_ZERO
-  int64_t i116;
-  i116 = T0.stride[0] * 4;
-  int64_t i275;
-  i275 = -T0.size[0];
+  float* ptr44;
+  ptr44 = T0.data;
+  float* ptr114;
+  ptr114 = ptr44 + (T0.stride[0] * 4);
+  int64_t i295;
+  i295 = -T0.size[0];
   smem_offset = alignBufferSize(smem_offset, 16);
-  float* T1 = reinterpret_cast(array + smem_offset);
+  float* T4 = reinterpret_cast(array + smem_offset);
   smem_offset += (15 * sizeof(float));
   #pragma unroll
-  for(nvfuser_index_t i22 = 0; i22 < 4; ++i22) {
-    int64_t i47;
-    i47 = 3 * i22;
-    int64_t i58;
-    i58 = T0.stride[0] * i22;
-    bool b254;
-    b254 = 0 < (T0.size[0] - (i22 + nvfuser_zero));
+  for(nvfuser_index_t i18 = 0; i18 < 4; ++i18) {
+    float* ptr51;
+    ptr51 = ptr44 + (T0.stride[0] * i18);
+    unsigned i77;
+    i77 = (toSmem(T4)) + (12 * i18);
+    bool b274;
+    b274 = 0 < (T0.size[0] - (i18 + nvfuser_zero));
     #pragma unroll
-    for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
-      T1[(i47 + i21)] = 0;
-    }
-    #pragma unroll
-    for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
-      int64_t i49;
-      i49 = i21 + nvfuser_zero;
-      if ((b254 && (i49 < 3))) {
-        T1[(i47 + i21)]
-           = T0[(i58 + (T0.stride[1] * i49))];
-      }
+    for(nvfuser_index_t i17 = 0; i17 < 3; ++i17) {
+      int64_t i38;
+      i38 = i17 + nvfuser_zero;
+      Ampere::cpAsyncCa((i77 + (4 * i17)),(ptr51 + (T0.stride[1] * i38)),(b274 && (i38 < 3)));
     }
+    Ampere::cpAsyncCommit();
   }
   NVFUSER_UPDATE_MAGIC_ZERO
-  float T2[2];
-  T2[0]
-     = T1[0];
+  Ampere::cpAsyncPartialBarrier<3>();
+  float T1[2];
+  T1[0]
+     = T4[0];
   #pragma unroll 1
-  for(nvfuser_index_t i23 = 0; i23 < T0.size[0]; ++i23) {
-    int64_t i96;
-    i96 = 3 * ((4 + i23) % 5);
-    int64_t i118;
-    i118 = i116 + (T0.stride[0] * i23);
-    int64_t i161;
-    i161 = 1 + (3 * (i23 % 5));
-    int64_t i197;
-    i197 = 3 * i23;
-    bool b283;
-    b283 = (i275 + i23) < -4;
-    bool b307;
-    b307 = 0 < (T0.size[0] - i23);
-    #pragma unroll
-    for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
-      T1[(i96 + i21)] = 0;
-    }
-    NVFUSER_UPDATE_MAGIC_ZERO
+  for(nvfuser_index_t i19 = 0; i19 < T0.size[0]; ++i19) {
+    float* ptr115;
+    ptr115 = ptr114 + (T0.stride[0] * i19);
+    unsigned i161;
+    i161 = (toSmem(T4)) + (12 * ((4 + i19) % 5));
+    int64_t i181;
+    i181 = 1 + (3 * (i19 % 5));
+    int64_t i217;
+    i217 = 3 * i19;
+    bool b303;
+    b303 = (i295 + i19) < -4;
+    bool b327;
+    b327 = 0 < (T0.size[0] - i19);
+    Ampere::cpAsyncPartialBarrier<3>();
     #pragma unroll
-    for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) {
-      int64_t i98;
-      i98 = i21 + nvfuser_zero;
-      if ((b283 && (i98 < 3))) {
-        T1[(i96 + i21)]
-           = T0[(i118 + (T0.stride[1] * i98))];
-      }
+    for(nvfuser_index_t i17 = 0; i17 < 3; ++i17) {
+      int64_t i87;
+      i87 = i17 + nvfuser_zero;
+      Ampere::cpAsyncCa((i161 + (4 * i17)),(ptr115 + (T0.stride[1] * i87)),(b303 && (i87 < 3)));
     }
     NVFUSER_UPDATE_MAGIC_ZERO
+    Ampere::cpAsyncCommit();
     #pragma unroll
-    for(nvfuser_index_t i26 = 0; i26 < 2; ++i26) {
-      int64_t i189;
-      i189 = i26 + nvfuser_zero;
-      T2[((1 + i26) % 2)]
-         = T1[(i161 + i26)];
-      float T3[1];
-      T3[0]
-         = T2[(i26 % 2)];
-      if ((b307 && (i189 < 3))) {
-        T4[(i197 + i189)]
-           = T3[0];
+    for(nvfuser_index_t i22 = 0; i22 < 2; ++i22) {
+      int64_t i209;
+      i209 = i22 + nvfuser_zero;
+      T1[((1 + i22) % 2)]
+         = T4[(i181 + i22)];
+      float T2[1];
+      T2[0]
+         = T1[(i22 % 2)];
+      if ((b327 && (i209 < 3))) {
+        T3[(i217 + i209)]
+           = T2[0];
       }
     }
     NVFUSER_UPDATE_MAGIC_ZERO
-    float T3[1];
-    T3[0]
-       = T2[0];
-    if (b307) {
-      T4[(2 + i197)]
-         = T3[0];
+    float T2[1];
+    T2[0]
+       = T1[0];
+    if (b327) {
+      T3[(2 + i217)]
+         = T2[0];
     }
     NVFUSER_UPDATE_MAGIC_ZERO
-    T2[0]
-       = T1[(3 * ((1 + i23) % 5))];
+    T1[0]
+       = T4[(3 * ((1 + i19) % 5))];
   }
 }
)";
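For readers unfamiliar with the two addressing modes this patch distinguishes, the following stand-alone CUDA sketch (not part of the patch; the helper and variable names are hypothetical) mirrors what the generated test kernel above does: a global tensor is addressed through its raw data pointer plus an element offset, while a shared memory tensor is addressed as a 32-bit shared-memory address, obtained the same way as nvfuser's `toSmem` helper (a `cvta`-style conversion), plus a byte offset.

// Illustrative sketch only; assumes CUDA 11+ for __cvta_generic_to_shared.
#include <cassert>
#include <cuda_runtime.h>

__device__ unsigned toSharedAddress(const void* smem_ptr) {
  // Convert a generic pointer into a 32-bit shared-memory address, which is
  // the form cp.async / ld.matrix style instructions expect.
  return static_cast<unsigned>(__cvta_generic_to_shared(smem_ptr));
}

__global__ void addressingDemo(const float* in, float* out, int n) {
  extern __shared__ float tile[];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  // Global tensor: base pointer plus an element offset (compare "T0.data"
  // plus a strided index in the generated kernel above).
  const float* src = in + i;
  // Shared memory tensor: 32-bit shared address plus a byte offset (compare
  // "toSmem(T4)" plus "4 * i17" in the generated kernel above).
  unsigned dst_addr = toSharedAddress(tile) + sizeof(float) * threadIdx.x;
  // The byte-offset address identifies the same element as ordinary indexing.
  assert(dst_addr == toSharedAddress(&tile[threadIdx.x]));
  tile[threadIdx.x] = *src;
  __syncthreads();
  out[i] = tile[threadIdx.x];
}

int main() {
  const int n = 64;
  float* in = nullptr;
  float* out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; i++) {
    in[i] = static_cast<float>(i);
  }
  addressingDemo<<<1, n, n * sizeof(float)>>>(in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}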