[MatMul] Update generated code after memory index hoisting #1974

Open
wants to merge 20 commits into base: predicate_shift
168 changes: 153 additions & 15 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -504,12 +504,60 @@ class CudaKernelGenerator : private OptOutConstDispatch {
return index.str();
}

// Generate a tensor index that is directly added
// to a base address pointer, so all of the components
// are computed in units of bytes.
std::string genTensorAddressIndex(
const kir::TensorIndex* ti,
DataType dtype) {
bool first = true;
std::stringstream index;
for (auto* ind : ti->indices()) {
if (!ind->isZeroInt()) {
if (!first) {
index << " + ";
}

// Multiply all the components here by the size of the data
// type to get byte offset.
index << "(" << genInline(ind) << ")*" << dataTypeSize(dtype);
first = false;
}
}

// If there is a uniform component in this tensor index,
// add it as well.
// See also [Double Buffer Uniform Offset].
if (ti->uniformAddress() != nullptr) {
if (!first) {
index << " + ";
}
index << genInline(ti->uniformAddress());
first = false;
}

if (first) {
index << "0";
}

return index.str();
}
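For intuition, here is a small standalone sketch (not nvFuser code; the index names `i37`, `i38` and the uniform term `T5_uniform_off` are hypothetical) that mimics how this helper assembles the byte-offset string: zero components are dropped, the remaining components are scaled by the data-type size, and the hoisted uniform term is appended verbatim.

```cpp
// Standalone model of the string genTensorAddressIndex() builds, assuming a
// float tensor (4-byte elements) indexed by {0, i37, i38} with a hoisted
// uniform term "T5_uniform_off".
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> indices = {"0", "i37", "i38"};
  const std::string uniform_address = "T5_uniform_off";
  const int dtype_size = 4; // sizeof(float)

  std::stringstream index;
  bool first = true;
  for (const auto& ind : indices) {
    if (ind == "0") {
      continue; // isZeroInt(): zero components contribute nothing
    }
    if (!first) {
      index << " + ";
    }
    // Scale each component by the element size to get a byte offset.
    index << "(" << ind << ")*" << dtype_size;
    first = false;
  }
  if (!uniform_address.empty()) {
    if (!first) {
      index << " + ";
    }
    index << uniform_address; // already in bytes, added as-is
    first = false;
  }
  if (first) {
    index << "0";
  }

  // Prints: (i37)*4 + (i38)*4 + T5_uniform_off
  std::cout << index.str() << "\n";
  return 0;
}
```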

void handle(const kir::TensorIndex* ti) final {
bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID();
if (is_volatile) {
code_ << "*(volatile " << ti->getDataType().value() << "*)&";
}

if (ti->hasBaseAddress()) {
// WAR path to generate a tensor index with pointer content.
code_ << "reinterpret_cast<" << ti->view()->dtype() << "*>("
<< gen(ti->baseAddress()) << ")"
<< "[" << genTensorIndex(ti) << "]";
return;
}

code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]";
}
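To make the two emitted forms concrete, here is a hand-written CUDA sketch (illustrative only; `ptr_T0` stands in for a hoisted base address computed outside the hot loop, and all names are hypothetical) of what the indexing looks like with and without a hoisted base pointer:

```cpp
// Illustrative CUDA kernel showing the two indexing forms the handler emits.
__global__ void indexing_forms(const float* T0, float* T1, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }

  // 1) No hoisted base address: plain array indexing on the named tensor.
  const float a = T0[i];

  // 2) Hoisted base address: a byte-typed pointer computed once per tile is
  //    reinterpreted back to the element type and indexed with the remaining
  //    (cheap, per-thread) component.
  const char* ptr_T0 = reinterpret_cast<const char*>(T0) +
      static_cast<size_t>(blockIdx.x) * blockDim.x * sizeof(float);
  const float b = reinterpret_cast<const float*>(ptr_T0)[threadIdx.x];

  // Both paths read the same element.
  T1[i] = a + b;
}
```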

@@ -545,20 +593,41 @@ class CudaKernelGenerator : private OptOutConstDispatch {
return ss.str();
}

//! Generates the given value as a pointer address, as either:
//!  1. hoisted_base_ptr + address_index
//!  2. &Tensor[index]
//! depending on whether the given index value carries
//! a hoisted component.
std::string genMaybeHoistedPointer(const Val* val) {
auto ti = dynamic_cast<const kir::TensorIndex*>(val);
TORCH_INTERNAL_ASSERT(ti != nullptr, "only support tensor index input");
std::stringstream ss;

if (ti->hasBaseAddress()) {
ss << genTensorAddressIndex(ti, ti->view()->dtype()) << ","
<< gen(ti->baseAddress());
} else {
ss << "&" << gen(ti) << "\n";
}

return ss.str();
}

// Utility function to emit a cp.async intrinsic
void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
auto dtype = ldst->in()->getDataType().value();

if (ldst->predicate() == nullptr) {
// Out of line predicate variant
indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">("
<< genMaybeHoistedPointer(ldst->out()) << ","
<< genMaybeHoistedPointer(ldst->in()) << ");\n";
} else {
// Inline predicate variant
indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ","
indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">("
<< genMaybeHoistedPointer(ldst->out()) << ","
<< genMaybeHoistedPointer(ldst->in()) << ","
<< genInline(ldst->predicate()) << ");\n";
}
}
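Because genMaybeHoistedPointer emits either a `(byte_offset, base)` pair or a plain `&T[i]` pointer, the runtime's cpAsync entry points presumably need overloads for both shapes. Below is a minimal sketch of the fully hoisted variant only, assuming sm_80+ and 16-byte transfers; it is not the actual nvFuser runtime, and the signature is inferred from the call sites above.

```cpp
// Sketch of an Ampere cp.async wrapper taking hoisted (offset, base) pairs:
// a 32-bit shared-memory address for the destination and a byte-typed
// global-memory pointer for the source. Illustrative only.
namespace sketch {

template <typename T, int VecSize>
__device__ inline void cpAsync(
    unsigned smem_byte_offset,
    unsigned smem_base_address,  // as produced by Turing::util::toSmem(...)
    unsigned gmem_byte_offset,
    const char* gmem_base) {     // as produced by the (DataPointer) cast
  constexpr int kBytes = sizeof(T) * VecSize;
  static_assert(kBytes == 16, "cp.async.cg copies 16 bytes per call");

  const unsigned smem_addr = smem_base_address + smem_byte_offset;
  const char* gmem_ptr = gmem_base + gmem_byte_offset;

  asm volatile(
      "cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr),
      "l"(gmem_ptr),
      "n"(kBytes));
}

} // namespace sketch
```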
@@ -579,8 +648,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
code_ << " (";
code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
<< ","
<< "&" << gen(ldst->in()) << ");\n";
<< "," << genMaybeHoistedPointer(ldst->in()) << ");\n";
genBankConflictCheck(ldst->in(), 16);
}

@@ -697,6 +765,25 @@ class CudaKernelGenerator : private OptOutConstDispatch {
!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) {
// Vectorized initialization
indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n";
} else if (
uop->out()->isA<kir::TensorIndex>() &&
uop->out()->as<kir::TensorIndex>()->useSmemAddress()) {
// A special resource string "smemReset" is used when
// the unary op writes to shared memory through
// the lifted 32-bit shared memory pointer.
// This mode is currently reserved for resetting
// shared memory space.
auto ti = uop->out()->as<kir::TensorIndex>();
// Special-case branch for smem reset.
// FIXME: only supports filling zero at the moment;
// could possibly be extended.
TORCH_INTERNAL_ASSERT(
uop->in()->isZero(), "only support filling zero in smem reset");

indent() << "smemReset<" << ti->view()->dtype() << ","
<< vector_word_size << ">(" << gen(ti->baseAddress())
<< "+" << genTensorAddressIndex(ti, ti->view()->dtype())
<< ");\n";
} else {
// Note: currently arraySet option is not vectorized, so it will
// rely on auto vectorization pass of cuda compiler.
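Returning to the smemReset path emitted a few lines up: the generated call passes a 32-bit shared-memory address (base plus byte offset). A minimal sketch of what such a helper could look like, purely illustrative and not the actual helper shipped in the runtime resource strings:

```cpp
// Zero-fills VecSize elements of T starting at a 32-bit shared-memory
// address (base + byte offset, as emitted by the smemReset branch above).
namespace sketch {

template <typename T, int VecSize>
__device__ inline void smemReset(unsigned smem_addr) {
  constexpr int kBytes = sizeof(T) * VecSize;
  static_assert(kBytes % 4 == 0, "assumes 4-byte-aligned, 4-byte-multiple fills");

#pragma unroll
  for (int i = 0; i < kBytes; i += 4) {
    // st.shared takes the 32-bit shared-memory window address directly.
    asm volatile("st.shared.u32 [%0], %1;\n" ::"r"(smem_addr + i), "r"(0u)
                 : "memory");
  }
}

} // namespace sketch
```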
@@ -2607,7 +2694,11 @@ class CudaKernelGenerator : private OptOutConstDispatch {
alloc_map_.emplace(alloc->buffer(), alloc);

if (!alloc->buffer()->isA<TensorView>()) {
indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n";
indent() << buffer_dtype << " " << gen(alloc->buffer());
if (alloc->zeroInit()) {
code_ << " = 0";
}
code_ << ";\n";
return;
}
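The effect of the zeroInit() branch on the generated source is small but worth spelling out; with a hypothetical scalar buffer `i55` of type `int64_t`, the emitted declaration changes roughly as follows (names are illustrative):

```cpp
// Without zeroInit():
int64_t i55;
// With zeroInit():
int64_t i55 = 0;
```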

@@ -2684,12 +2775,59 @@ }
}

void handle(const kir::AddressCompute* address_compute) final {
indent() << "// Address tensor for indexing "
<< varName(address_compute->dataTv()) << "\n";
indent() << gen(address_compute->addressTv()) << " = "
<< genTensorIndex(
address_compute->dataTv()->as<kir::TensorIndex>())
<< ";\n";
// FIXME:
// All the global/shared memory address/offset manipulations
// used to reduce register usage are currently lumped into this
// single kernel IR operator.
//
// If we need to commit to the current codegen tweaks for
// longer, consider separating them into dedicated IR nodes.
if (address_compute->opType() ==
kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) {
indent() << "doubleBufferSwitch<" << address_compute->stageNumber() << ","
<< address_compute->loopOffset() << ">("
<< gen(address_compute->doubleBufferSwitchIndex()) << ","
<< gen(address_compute->loopIndex()) << ","
<< gen(address_compute->doubleBufferByteSize()) << ");\n";
} else if (
address_compute->opType() ==
kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) {
indent() << "doubleBufferUpdate<" << address_compute->stageNumber() << ","
<< address_compute->loopOffset() << ">("
<< gen(address_compute->addressTv()) << ","
<< gen(address_compute->loopIndex()) << ","
<< gen(address_compute->doubleBufferByteSize()) << ");\n";
} else if (
address_compute->opType() ==
kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) {
indent() << gen(address_compute->addressTv()) << "+="
<< genTensorAddressIndex(
address_compute->incrementValue(),
address_compute->dataTv()->dtype())
<< ";\n";
} else if (
address_compute->opType() ==
kir::AddressCompute::AddressComputeOpType::GMEM_DECREMENT) {
indent() << gen(address_compute->addressTv()) << "-="
<< genTensorAddressIndex(
address_compute->incrementValue(),
address_compute->dataTv()->dtype())
<< ";\n";
} else {
indent() << "//Base Address:::\n";
indent() << gen(address_compute->addressTv());

if (address_compute->addressTv()->dtype() == DataType::Pointer) {
code_ << " = (DataPointer) &"
<< gen(address_compute->dataTv()->as<kir::TensorIndex>())
<< ";\n";
} else if (
address_compute->addressTv()->dtype() == DataType::SmemAddress) {
code_ << " = Turing::util::toSmem(&"
<< gen(address_compute->dataTv()->as<kir::TensorIndex>())
<< ");\n";
}
}
}
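The double-buffer helpers called above live in the runtime resource strings rather than in this file. As a rough mental model only (the wrap-around formulation, the reference parameters, and the shared-memory flavor of doubleBufferUpdate are all assumptions; the real implementations may differ), they could look something like this:

```cpp
// Rough model of the double-buffer index helpers invoked by the generated
// code above. Illustrative only; not the actual nvFuser runtime.
#include <cstdint>

namespace sketch {

// Recomputes the uniform byte offset that selects the current stage of a
// multi-stage (circular) shared-memory buffer.
template <int NumStages, int LoopOffset>
__device__ inline void doubleBufferSwitch(
    unsigned& switch_index,        // uniform offset consumed by the indexing
    const int64_t loop_index,
    const unsigned buffer_bytes) { // size of one stage, in bytes
  switch_index =
      static_cast<unsigned>((loop_index + LoopOffset) % NumStages) * buffer_bytes;
}

// Advances a hoisted base address to the next stage, wrapping back to the
// first stage after the last one.
template <int NumStages, int LoopOffset>
__device__ inline void doubleBufferUpdate(
    unsigned& address,             // lifted 32-bit shared-memory address
    const int64_t loop_index,
    const unsigned buffer_bytes) {
  address += buffer_bytes;
  if ((loop_index + LoopOffset) % NumStages == NumStages - 1) {
    address -= NumStages * buffer_bytes;
  }
}

} // namespace sketch
```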

void handle(const kir::GridSync* sync) final {
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor.cpp
@@ -65,6 +65,8 @@ typedef int int32_t;
typedef unsigned int uint32_t;
typedef long long int int64_t;
typedef unsigned long long int uint64_t;
typedef char* DataPointer;
typedef unsigned SmemAddress;
)";
}
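These two typedefs back the new address flavors used by the codegen above: `DataPointer` holds the byte-typed global-memory base produced by the `(DataPointer) &T[...]` cast, and `SmemAddress` holds the 32-bit shared-memory window address produced by `Turing::util::toSmem`. A sketch of how such a conversion is typically done (assumes CUDA 11's `__cvta_generic_to_shared`; the actual runtime helper may differ):

```cpp
typedef char* DataPointer;
typedef unsigned SmemAddress;

// Generic-to-shared address conversion as toSmem presumably performs it.
__device__ inline SmemAddress toSmemSketch(const void* ptr) {
  return static_cast<SmemAddress>(__cvta_generic_to_shared(ptr));
}
```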
