cp.async access global tensor via pointer (#2282)
zasdfgbnm authored Mar 6, 2023
1 parent b928665 commit 16a26a1
Showing 17 changed files with 228 additions and 158 deletions.
17 changes: 13 additions & 4 deletions third_party/nvfuser/csrc/codegen.cpp
@@ -526,12 +526,12 @@ class CudaKernelGenerator : private OptOutConstDispatch {
// Out of line predicate variant
code_ << "<" << dtype << ", " << vec_size << ">("
<< genInline(ldst->out()->as<kir::TensorIndex>()->index()) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
<< genInline(ldst->in()->as<kir::TensorIndex>()->index()) << ");\n";
} else {
// Inline predicate variant
code_ << "<" << dtype << ", " << vec_size << ">("
<< genInline(ldst->out()->as<kir::TensorIndex>()->index()) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ","
<< genInline(ldst->in()->as<kir::TensorIndex>()->index()) << ","
<< genInline(ldst->predicate()) << ");\n";
}
}
@@ -548,11 +548,20 @@ class CudaKernelGenerator : private OptOutConstDispatch {
<< ");\n";
}

void handle(const kir::SMemAddress* sop) final {
void handle(const kir::BaseAddress* sop) final {
if (!print_inline_) {
indent() << gen(sop->output(0)) << " = ";
}
code_ << "toSmem(" << ir_utils::varName(sop->smemTv()) << ")";
switch (sop->tv()->getMemoryType()) {
case MemoryType::Shared:
code_ << "toSmem(" << ir_utils::varName(sop->tv()) << ")";
break;
case MemoryType::Global:
code_ << ir_utils::varName(sop->tv()) << ".data";
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unsupported input for kir::BaseAddress");
}
if (!print_inline_) {
code_ << ";\n";
}
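Editor's note: with this change, handle(kir::BaseAddress*) emits `toSmem(T)` for a shared memory tensor and `T.data` for a global tensor. Below is a minimal sketch of how such a pair of addresses could be consumed by a cp.async copy on the device side. The helper names and signatures (toSmem, cpAsyncCa) are illustrative assumptions rather than the actual nvfuser runtime API; __cvta_generic_to_shared and the cp.async PTX form are standard CUDA (sm_80+).

// Illustrative device-side sketch, not part of this diff.
__device__ inline unsigned toSmem(const void* ptr) {
  // Convert a generic shared-memory pointer to a 32-bit shared-state-space
  // address, as required by cp.async.
  return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
}

template <typename T, int kVecSize>
__device__ inline void cpAsyncCa(unsigned smem_addr, const T* gmem_ptr) {
  // cp.async copies 4, 8, or 16 bytes from global to shared memory.
  constexpr int kBytes = int(sizeof(T)) * kVecSize;
  static_assert(kBytes == 4 || kBytes == 8 || kBytes == 16, "invalid cp.async size");
  asm volatile(
      "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr),
      "l"(gmem_ptr),
      "n"(kBytes));
}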
24 changes: 18 additions & 6 deletions third_party/nvfuser/csrc/dispatch.cpp
@@ -41,6 +41,10 @@ template <typename T>
void Val::dispatch(T handler, Val* val) {
switch (*(val->getValType())) {
case ValType::Scalar:
if (std::holds_alternative<PointerOf>(val->getDataType()->type)) {
ptr(handler)->handle(val->as<Int>());
return;
}
switch (std::get<PrimDataType>(val->getDataType()->type)) {
case DataType::Bool:
ptr(handler)->handle(val->as<Bool>());
@@ -273,8 +277,8 @@ void Expr::dispatch(T handler, Expr* expr) {
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
}
if (expr->isStrictlyA<kir::SMemAddress>()) {
ptr(handler)->handle(expr->as<kir::SMemAddress>());
if (expr->isStrictlyA<kir::BaseAddress>()) {
ptr(handler)->handle(expr->as<kir::BaseAddress>());
return;
}
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -294,6 +298,10 @@ template <typename T>
void Val::constDispatch(T handler, const Val* val) {
switch (*(val->getValType())) {
case ValType::Scalar:
if (std::holds_alternative<PointerOf>(val->getDataType()->type)) {
ptr(handler)->handle(val->as<Int>());
return;
}
switch (std::get<PrimDataType>(val->getDataType()->type)) {
case DataType::Bool:
ptr(handler)->handle(val->as<Bool>());
@@ -530,8 +538,8 @@ void Expr::constDispatch(T handler, const Expr* expr) {
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
}
if (expr->isStrictlyA<kir::SMemAddress>()) {
ptr(handler)->handle(expr->as<kir::SMemAddress>());
if (expr->isStrictlyA<kir::BaseAddress>()) {
ptr(handler)->handle(expr->as<kir::BaseAddress>());
return;
}
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -562,6 +570,10 @@ template <typename T>
void Val::mutatorDispatch(T mutator, Val* val) {
switch (*(val->getValType())) {
case ValType::Scalar:
if (std::holds_alternative<PointerOf>(val->getDataType()->type)) {
ptr(mutator)->mutate(val->as<Int>());
return;
}
switch (std::get<PrimDataType>(val->getDataType()->type)) {
case DataType::Bool:
ptr(mutator)->mutate(val->as<Bool>());
@@ -888,7 +900,7 @@ void OptOutConstDispatch::handle(const kir::VectorizedWelfordOp* stmt) {
void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::SMemAddress* stmt) {
void OptOutConstDispatch::handle(const kir::BaseAddress* stmt) {
unhandled(stmt);
}

@@ -1062,7 +1074,7 @@ void OptOutDispatch::handle(kir::VectorizedWelfordOp* stmt) {
void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::SMemAddress* stmt) {
void OptOutDispatch::handle(kir::BaseAddress* stmt) {
unhandled(stmt);
}

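Editor's note: the dispatch changes above route any scalar whose DataType holds the PointerOf alternative to the Int handler, i.e. pointer-typed values are represented by Int. Below is a small self-contained sketch of that variant-based check, assuming a simplified DataType model rather than nvfuser's own definitions.

// Illustrative sketch with a simplified DataType model, not nvfuser's own.
#include <iostream>
#include <memory>
#include <variant>

enum class PrimDataType { Bool, Int, Double };

struct DataType;

struct PointerOf {
  std::shared_ptr<DataType> type;  // pointee type
};

struct DataType {
  std::variant<PrimDataType, PointerOf> type;
};

void dispatchScalar(const DataType& dtype) {
  // Pointer-typed scalars are handled by the integer handler, mirroring the
  // std::holds_alternative<PointerOf> checks added in this commit.
  if (std::holds_alternative<PointerOf>(dtype.type)) {
    std::cout << "handled as Int (pointer)\n";
    return;
  }
  switch (std::get<PrimDataType>(dtype.type)) {
    case PrimDataType::Bool:
      std::cout << "handled as Bool\n";
      break;
    case PrimDataType::Int:
      std::cout << "handled as Int\n";
      break;
    case PrimDataType::Double:
      std::cout << "handled as Double\n";
      break;
  }
}

int main() {
  dispatchScalar({PrimDataType::Int});
  dispatchScalar({PointerOf{std::make_shared<DataType>(DataType{PrimDataType::Double})}});
}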
6 changes: 3 additions & 3 deletions third_party/nvfuser/csrc/dispatch.h
@@ -119,7 +119,7 @@ class VectorizedWelfordOp;
class AllocateFusedReduction;
class InitMagicZero;
class UpdateMagicZero;
class SMemAddress;
class BaseAddress;

} // namespace kir

@@ -195,7 +195,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::GroupedGridWelford*);
virtual void handle(const kir::VectorizedWelfordOp*);
virtual void handle(const kir::AllocateFusedReduction*);
virtual void handle(const kir::SMemAddress*);
virtual void handle(const kir::BaseAddress*);
};

class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
@@ -268,7 +268,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::GroupedGridWelford* stmt);
virtual void handle(kir::VectorizedWelfordOp* stmt);
virtual void handle(kir::AllocateFusedReduction* stmt);
virtual void handle(kir::SMemAddress* stmt);
virtual void handle(kir::BaseAddress* stmt);
};

class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch {
2 changes: 1 addition & 1 deletion third_party/nvfuser/csrc/expr_evaluator.cpp
@@ -104,7 +104,7 @@ c10::optional<EvaluatorValue> ExpressionEvaluator::evaluate(const Val* value) {
if (!maybe_concrete_value.has_value()) {
if (auto def = value->definition()) {
FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
if (def->isA<kir::SMemAddress>()) {
if (def->isA<kir::BaseAddress>()) {
return c10::nullopt;
}
std::vector<EvaluatorValue> inputs;
71 changes: 53 additions & 18 deletions third_party/nvfuser/csrc/index_compute.cpp
@@ -22,6 +22,8 @@
#include <transform_iter.h>
#include <transform_replay.h>

#include <memory>

namespace nvfuser {

namespace {
@@ -312,6 +314,23 @@ Val* getProducerIndexWithPartialSplit(
producer_index, SimplifyingIrBuilder::create<Int>(diff->evaluateInt()));
}

Val* getTensorBaseAddress(TensorView* tv) {
Val* output = nullptr;
switch (auto memtype = tv->getMemoryType()) {
case MemoryType::Global:
output = IrBuilder::newScalar(
PointerOf{std::make_shared<DataType>(*tv->getDataType())});
break;
case MemoryType::Shared:
output = IrBuilder::newScalar(DataType::SMemAddress);
break;
default:
TORCH_CHECK(false, "Unsupported memory type ", memtype);
}
IrBuilder::create<kir::BaseAddress>(output, tv);
return output;
}

} // namespace

void IndexCompute::handle(Split* split) {
@@ -2135,26 +2154,34 @@ Val* Index::getProducerStridedIndices(
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool cvta_smem_address) {
bool generate_pointer) {
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
if (producer->domain()->noReductions().size() == 0) {
return GpuLower::current()->kernel()->zeroVal();
if (generate_pointer) {
return getTensorBaseAddress(producer);
} else {
return GpuLower::current()->kernel()->zeroVal();
}
}

if (producer->getMemoryType() == MemoryType::Global) {
return sumVals(getGlobalProducerStridedIndices(
auto index = sumVals(getGlobalProducerStridedIndices(
producer, consumer, loops, rotated_loops, override_index));
if (generate_pointer) {
return SimplifyingIrBuilder::addExpr(
getTensorBaseAddress(producer), index);
} else {
return index;
}
} else {
auto index = sumVals(getNonGlobalProducerStridedIndices(
producer, consumer, loops, rotated_loops, override_index));
if (cvta_smem_address && producer->getMemoryType() == MemoryType::Shared) {
auto base_address = IrBuilder::newScalar(DataType::SMemAddress);
IrBuilder::create<kir::SMemAddress>(base_address, producer);
if (generate_pointer) {
auto index_bytes = IrBuilder::mulExpr(
index,
IrBuilder::newConstant(
dataTypeSize(*producer->getDataType()), *index->getDataType()));
return IrBuilder::addExpr(base_address, index_bytes);
return IrBuilder::addExpr(getTensorBaseAddress(producer), index_bytes);
} else {
return index;
}
@@ -2168,14 +2195,14 @@ kir::TensorIndex* Index::getProducerIndex(
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool cvta_smem_address) {
bool generate_pointer) {
auto index = getProducerStridedIndices(
producer,
consumer,
loops,
rotated_loops,
override_index,
cvta_smem_address);
generate_pointer);
index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);
return SimplifyingIrBuilder::create<kir::TensorIndex>(producer, index);
}
@@ -2185,26 +2212,34 @@ Val* Index::getConsumerStridedIndices(
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<int, Val*>& override_index,
bool cvta_smem_address) {
bool generate_pointer) {
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
if (consumer->domain()->noReductions().size() == 0) {
return GpuLower::current()->kernel()->zeroVal();
if (generate_pointer) {
return getTensorBaseAddress(consumer);
} else {
return GpuLower::current()->kernel()->zeroVal();
}
}

if (consumer->getMemoryType() == MemoryType::Global) {
return sumVals(getGlobalConsumerStridedIndices(
auto index = sumVals(getGlobalConsumerStridedIndices(
consumer, loops, rotated_loops, override_index));
if (generate_pointer) {
return SimplifyingIrBuilder::addExpr(
getTensorBaseAddress(consumer), index);
} else {
return index;
}
} else {
auto index = sumVals(
getNonGlobalConsumerStridedIndices(consumer, loops, rotated_loops));
if (cvta_smem_address && consumer->getMemoryType() == MemoryType::Shared) {
auto base_address = IrBuilder::newScalar(DataType::SMemAddress);
IrBuilder::create<kir::SMemAddress>(base_address, consumer);
if (generate_pointer) {
auto index_bytes = IrBuilder::mulExpr(
index,
IrBuilder::newConstant(
dataTypeSize(*consumer->getDataType()), *index->getDataType()));
return IrBuilder::addExpr(base_address, index_bytes);
return IrBuilder::addExpr(getTensorBaseAddress(consumer), index_bytes);
} else {
return index;
}
@@ -2217,9 +2252,9 @@ kir::TensorIndex* Index::getConsumerIndex(
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<int, Val*>& override_index,
bool cvta_smem_address) {
bool generate_pointer) {
auto index = getConsumerStridedIndices(
consumer, loops, rotated_loops, override_index, cvta_smem_address);
consumer, loops, rotated_loops, override_index, generate_pointer);
index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);
return SimplifyingIrBuilder::create<kir::TensorIndex>(consumer, index);
}
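Editor's note: in the generate_pointer paths above, a global tensor yields `base + index` (the base is a typed data pointer, so the element index is added directly), while a shared memory tensor yields `base + index * dataTypeSize(...)` (the base is a raw 32-bit address, so the offset must be in bytes). A small standalone sketch of that distinction, using plain C++ pointer arithmetic as a stand-in for the generated index expressions:

// Illustrative sketch of element-offset vs. byte-offset addressing.
#include <cassert>
#include <cstdint>

int main() {
  float buffer[8] = {};

  // Global tensor case: the base is a typed pointer (T.data), so the element
  // index can be added directly; the compiler scales by sizeof(float).
  float* global_base = buffer;
  std::int64_t index = 3;
  float* global_elem = global_base + index;

  // Shared tensor case: the base is a raw address (from toSmem), so the index
  // must be scaled to bytes explicitly, matching `index * dataTypeSize(...)`.
  std::uintptr_t raw_base = reinterpret_cast<std::uintptr_t>(buffer);
  std::uintptr_t raw_elem = raw_base + index * sizeof(float);

  assert(reinterpret_cast<float*>(raw_elem) == global_elem);
  return 0;
}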
20 changes: 11 additions & 9 deletions third_party/nvfuser/csrc/index_compute.h
@@ -346,26 +346,28 @@ class Index {
// Consumer = Producer
// i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
// Producer indexing dispatch
// The argument `cvta_smem_address` specifies whether to use `cvta` ptx to
// convert shared memory address to unsigned int for indexing. This argument
// is effective only if the indexed tensor is a shared memory tensor. On other
// memory type, this argument will be ignored. Search `toSmem` in the codebase
// for additional information.
// The argument `generate_pointer` specifies whether to generate a pointer for
// the tensor. For a global memory tensor, this generates T1.data. For a shared
// memory tensor, it uses the `cvta` PTX instruction to convert the shared
// memory address to an unsigned int for indexing. Search `toSmem` in the
// codebase for additional information. This argument is effective only if the
// indexed tensor is a shared memory or global tensor; on other memory types it
// will cause an error.
static kir::TensorIndex* getProducerIndex(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool cvta_smem_address = false);
bool generate_pointer = false);

// Consumer index dispatch
static kir::TensorIndex* getConsumerIndex(
TensorView* consumer,
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<int, Val*>& override_index = {},
bool cvta_smem_address = false);
bool generate_pointer = false);

//! Returns a vector of strided indices mapped onto the (rfactor)
//! root domain of a producer tensor. The size of the returned
@@ -377,7 +379,7 @@
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool cvta_smem_address = false);
bool generate_pointer = false);

//! Returns a vector of strided indices mapped onto the (rfactor)
//! root domain of a consumer tensor. The size of the returned
@@ -388,7 +390,7 @@
const std::vector<kir::ForLoop*>& loops,
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<int, Val*>& override_index = {},
bool cvta_smem_address = false);
bool generate_pointer = false);

//! Returns the logical index linearized from a multi-dimension address into a
//! linear memory address a consumer tensor. The returned index is intended to
Expand Down
3 changes: 2 additions & 1 deletion third_party/nvfuser/csrc/ir_base_nodes.cpp
@@ -259,7 +259,8 @@ int64_t Val::evaluateInt() {
auto evaluated_val = ee.evaluate(this);
TORCH_INTERNAL_ASSERT(
evaluated_val.has_value(),
"Detected a const integer but failed to infer its value.");
"Detected a const integer but failed to infer its value: ",
toInlineString());
return evaluated_val->as<int64_t>();
}

15 changes: 11 additions & 4 deletions third_party/nvfuser/csrc/ir_builder.cpp
@@ -11,6 +11,9 @@
namespace nvfuser {

Val* IrBuilder::newScalar(DataType dtype) {
if (isPointerType(dtype)) {
return IrBuilder::create<Int>(dtype);
}
switch (std::get<PrimDataType>(dtype.type)) {
case DataType::Bool:
return IrBuilder::create<Bool>();
Expand All @@ -20,7 +23,6 @@ Val* IrBuilder::newScalar(DataType dtype) {
case DataType::Int:
case DataType::Int32:
case DataType::Index:
case DataType::SMemAddress:
return IrBuilder::create<Int>(dtype);
case DataType::ComplexFloat:
case DataType::ComplexDouble:
@@ -49,11 +51,16 @@ Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
// than just allowing the integer type promotion for the two inputs as below.
// Note that this is only needed for integer types. See also PR #2228.
if (lhs->dtype() != rhs->dtype()) {
if (lhs->dtype() == DataType::SMemAddress ||
rhs->dtype() == DataType::SMemAddress) {
if (isPointerType(lhs->dtype())) {
TORCH_INTERNAL_ASSERT(isIntegralType(rhs->dtype()));
TORCH_INTERNAL_ASSERT(
op_type == BinaryOpType::Add || op_type == BinaryOpType::Sub);
dtype = lhs->dtype();
} else if (isPointerType(rhs->dtype())) {
TORCH_INTERNAL_ASSERT(isIntegralType(lhs->dtype()));
TORCH_INTERNAL_ASSERT(
op_type == BinaryOpType::Add || op_type == BinaryOpType::Sub);
dtype = DataType::SMemAddress;
dtype = rhs->dtype();
} else if (
(lhs->dtype() == DataType::Int && rhs->dtype() == DataType::Int32) ||
(lhs->dtype() == DataType::Int32 && rhs->dtype() == DataType::Int)) {
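Editor's note: the newArithmeticExpr change generalizes the earlier SMemAddress special case: when exactly one operand is pointer-typed, only Add and Sub are permitted, the other operand must be integral, and the result takes the pointer operand's type. A minimal standalone model of that promotion rule, using simplified type tags rather than nvfuser's DataType/BinaryOpType:

// Illustrative sketch of the pointer-arithmetic promotion rule added above.
#include <cassert>
#include <stdexcept>

enum class Type { Int32, Int64, GlobalPtr, SMemAddress };
enum class Op { Add, Sub, Mul };

bool isPointer(Type t) {
  return t == Type::GlobalPtr || t == Type::SMemAddress;
}

bool isIntegral(Type t) {
  return t == Type::Int32 || t == Type::Int64;
}

// Result type of `lhs op rhs` when the operand types differ.
Type promote(Op op, Type lhs, Type rhs) {
  if (isPointer(lhs)) {
    // pointer +/- integer -> pointer
    if (!isIntegral(rhs) || (op != Op::Add && op != Op::Sub)) {
      throw std::invalid_argument("unsupported pointer arithmetic");
    }
    return lhs;
  }
  if (isPointer(rhs)) {
    // integer +/- pointer -> pointer (commuted case, result keeps rhs type)
    if (!isIntegral(lhs) || (op != Op::Add && op != Op::Sub)) {
      throw std::invalid_argument("unsupported pointer arithmetic");
    }
    return rhs;
  }
  // Integer-only promotion (Int32 with Int64 -> Int64); see also PR #2228.
  return (lhs == Type::Int64 || rhs == Type::Int64) ? Type::Int64 : Type::Int32;
}

int main() {
  assert(promote(Op::Add, Type::GlobalPtr, Type::Int64) == Type::GlobalPtr);
  assert(promote(Op::Sub, Type::Int32, Type::SMemAddress) == Type::SMemAddress);
  assert(promote(Op::Add, Type::Int32, Type::Int64) == Type::Int64);
  return 0;
}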
(Diff truncated: the remaining 9 changed files are not shown here.)
